# go-algo-server / tcp_client.py
# Last commit: zhangyingjie, 6 Mar (5 KB) — 初版提交(本地测试版) / initial commit (local test version)
import threading
import socket
import queue
import time
import asyncio

from global_logger import logger


class AsyncTCPClient:
    """Asyncio TCP client that writes queued text messages to one server.

    Messages are queued with ``send()`` and written by the ``send_loop()``
    coroutine, which (re)connects automatically and drops messages whose
    expiry time has passed. Nothing is read from the server (the stream
    reader returned by ``asyncio.open_connection`` is discarded).
    """

    def __init__(self, server_ip, server_port):
        self.server_ip = server_ip
        self.server_port = server_port
        self.writer = None
        # NOTE(review): on Python < 3.10 asyncio.Queue binds to the event loop
        # current at construction, so create this client inside the loop that
        # runs send_loop().
        self.queue = asyncio.Queue()
        self.connected = False
        self.running = True

    async def connect(self):
        """Open a connection, retrying every 5 seconds until success or ``stop()``.

        Any existing writer is closed first. On return ``self.connected`` is
        True, unless the client was stopped before a connection was made.
        """
        # Close any existing writer before opening a new connection.
        if self.writer is not None:
            self.connected = False
            try:
                self.writer.close()
                await self.writer.wait_closed()
            except Exception as e:
                logger.exception(f"关闭现有连接时出错: {e}")
            self.writer = None

        while self.running:
            try:
                logger.info(f"连接 TCP 服务器 {self.server_ip}:{self.server_port}")
                # The reader is discarded: this client only writes.
                _, self.writer = await asyncio.open_connection(self.server_ip, self.server_port)
                self.connected = True
                logger.info("TCP 连接成功")
                return
            except Exception as e:
                logger.exception(f"TCP 连接失败: {e}, 5 秒后重试")
                self.connected = False
                await asyncio.sleep(5)

    async def send_loop(self):
        """Write queued messages to the server until ``stop()`` is called.

        Expired messages are dropped. A message whose write fails is kept and
        retried after reconnecting (5 second pause) before any newer message,
        until it expires. A write that failed during ``drain()`` may therefore
        be delivered twice (at-least-once delivery).

        While idle this coroutine awaits ``self.queue.get()``, so after
        ``stop()`` the caller should cancel the task running it.
        """
        pending = None  # (message, msg_time, expire_second) awaiting (re)send
        while self.running:
            if not self.connected:
                logger.info(f"TCP 服务器 {self.server_ip}:{self.server_port}连接断开,重新连接")
                await self.connect()
                if not self.connected:
                    # connect() only gives up when the client has been stopped.
                    break
            try:
                if pending is None:
                    pending = await self.queue.get()
                message, msg_time, expire_second = pending
                # Drop messages that are already past their expiry.
                current_time = time.time()
                if current_time - msg_time > expire_second:
                    logger.info(f"丢弃过期消息: {message}, 已过期 {current_time - msg_time:.1f} 秒")
                    pending = None
                    continue

                self.writer.write(message.encode('utf-8'))
                await self.writer.drain()
                pending = None
                logger.info(f"TCP 发送数据: {message}")
            except Exception as e:
                logger.exception(f"TCP 发送失败: {e}")
                self.connected = False
                await asyncio.sleep(5)

    async def send(self, message: str, expire_second=60):
        """Queue a message for sending, with an expiry time.

        Args:
            message: text to send (UTF-8 encoded when written).
            expire_second: seconds after queueing beyond which the message is
                dropped unsent; default 60.
        """
        logger.debug(f"添加消息到TCP队列: {message}")
        await self.queue.put((message, time.time(), expire_second))

    def stop(self):
        """Stop the client: clear ``running`` and close the current writer.

        Does not wait for ``send_loop()``; if it is blocked waiting on the
        queue, the caller must cancel its task.
        """
        self.running = False
        self.connected = False
        if self.writer is not None:
            self.writer.close()


class TCPClient:
    """Threaded TCP client: a background thread sends queued text messages.

    Messages are queued with ``send()``; the thread started by ``start()``
    runs ``send_loop()``, which (re)connects automatically and drops messages
    whose expiry time has passed.
    """

    def __init__(self, server_ip, server_port):
        self.server_ip = server_ip
        self.server_port = server_port
        self.sock = None
        self.queue = queue.Queue()
        self.lock = threading.Lock()  # serializes send_data()
        self.connected = False
        self.running = True
        self.send_thread = None

    def _close_sock(self):
        """Close the current socket (if any) and clear ``self.sock``."""
        # Swap first so a concurrent caller never sees a half-closed socket,
        # and a second caller finds None instead of raising.
        sock, self.sock = self.sock, None
        if sock is not None:
            try:
                sock.close()
            except OSError as e:
                logger.debug(f"关闭 socket 时出错: {e}")

    def connect(self):
        """Open a TCP connection, retrying every 5 seconds until success or ``stop()``.

        Returns:
            True once connected; False if the client was stopped first.
        """
        while self.running:
            # Release any previous socket so reconnecting does not leak it.
            self._close_sock()
            sock = None
            try:
                logger.info(f"连接 TCP 服务器 {self.server_ip}:{self.server_port}")
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.connect((self.server_ip, self.server_port))
            except Exception as e:
                logger.exception(f"TCP 连接失败: {e}, 5 秒后重试")
                # Close the failed socket now instead of leaking one per attempt.
                if sock is not None:
                    sock.close()
                self.connected = False
                time.sleep(5)
                continue
            # Publish the socket only once it is actually connected.
            self.sock = sock
            self.connected = True
            logger.info("TCP 连接成功")
            return True
        return False

    def send_data(self, data: str):
        """Send *data* (UTF-8 encoded), connecting first if needed. Thread-safe.

        Returns:
            True on success. False if the send failed (the socket is then
            closed and ``connected`` cleared) or the client was stopped
            before a connection could be made.
        """
        with self.lock:
            if not self.connected and not self.connect():
                return False

            try:
                self.sock.sendall(data.encode('utf-8'))
            except Exception as e:
                logger.exception(f"TCP 发送失败: {e}")
                self.connected = False
                self._close_sock()
                return False
            logger.info(f"TCP 发送数据: {data}")
            return True

    def send_loop(self):
        """Worker loop: send queued messages until ``stop()`` is called.

        Expired messages are dropped. A message whose send fails is kept and
        retried (after a 5 second pause) before any newer message, so queue
        order is preserved; it is dropped once it expires.
        """
        pending = None  # (message, msg_time, expire_second) awaiting (re)send
        while self.running:
            try:
                if not self.connected:
                    logger.info(f"TCP 服务器 {self.server_ip}:{self.server_port}连接断开,重新连接")
                    self.connect()

                if pending is None:
                    # Short timeout so the loop re-checks self.running regularly.
                    pending = self.queue.get(timeout=1)
                message, msg_time, expire_second = pending
                # Drop messages that are already past their expiry.
                current_time = time.time()
                if current_time - msg_time > expire_second:
                    logger.info(f"丢弃过期消息: {message}, 已过期 {current_time - msg_time:.1f} 秒")
                    pending = None
                    continue

                if self.send_data(message):
                    pending = None
                else:
                    time.sleep(5)
            except queue.Empty:
                continue
            except Exception as e:
                logger.exception(f"TCP 工作线程错误: {e}")
                # Drop the message that triggered an unexpected error so it
                # cannot wedge the loop.
                pending = None

    def start(self):
        """Start the daemon sender thread; returns self for chaining."""
        self.send_thread = threading.Thread(target=self.send_loop, daemon=True)
        self.send_thread.start()
        return self

    def send(self, message: str, expire_second=60):
        """Queue a message for sending, with an expiry time.

        Args:
            message: text to send (UTF-8 encoded when sent).
            expire_second: seconds after queueing beyond which the message is
                dropped unsent; default 60.
        """
        logger.debug(f"添加消息到TCP队列: {message}")
        self.queue.put((message, time.time(), expire_second))

    def stop(self):
        """Stop the sender and close the socket.

        Does not wait for the sender thread. It exits once it next checks
        ``running``, which can take up to about a second while waiting on the
        queue, or longer if it is sleeping after a failure or mid-connect.
        """
        self.running = False
        self.connected = False
        self._close_sock()