Newer
Older
go-algo-server / tcp_client.py
import threading
import socket
import queue
import time
import asyncio

from global_logger import logger


class AsyncTCPClient:
    def __init__(self, server_ip, server_port):
        self.server_ip = server_ip
        self.server_port = server_port
        self.writer = None
        self.reader = None
        self.queue = asyncio.Queue()
        self.connected = False
        self.running = True
        self.command_callbacks = []

    def register_command_handler(self, callback):
        """注册命令处理回调函数
        
        Args:
            callback: 异步回调函数,接收消息字符串作为参数
        """
        self.command_callbacks.append(callback)
        logger.info(f"已注册TCP命令处理回调函数: {callback.__name__}")

    async def connect(self):
        # 先关闭现有的writer
        if self.writer is not None:
            try:
                self.writer.close()
                await self.writer.wait_closed()
            except Exception as e:
                logger.exception(f"关闭现有连接时出错: {e}")
            self.writer = None

        while self.running:
            try:
                logger.info(f"连接 TCP 服务器 {self.server_ip}:{self.server_port}")
                self.reader, self.writer = await asyncio.open_connection(self.server_ip, self.server_port)
                self.connected = True
                logger.info("TCP 连接成功")
                return
            except Exception as e:
                logger.exception(f"TCP 连接失败: {e}, 5 秒后重试")
                self.connected = False
                await asyncio.sleep(5)

    async def send_loop(self):
        while self.running:
            if not self.connected:
                logger.info(f"TCP 服务器 {self.server_ip}:{self.server_port}连接断开,重新连接")
                await self.connect()
            try:
                message, msg_time, expire_second = await self.queue.get()
                # 检查消息是否已过期
                current_time = time.time()
                if current_time - msg_time > expire_second:
                    logger.info(f"丢弃过期消息: {message}, 已过期 {current_time - msg_time:.1f} 秒")
                    continue

                self.writer.write(message.encode('utf-8'))
                await self.writer.drain()
                logger.info(f"TCP 发送数据: {message}")
            except Exception as e:
                logger.exception(f"TCP 发送失败: {e}")
                self.connected = False
                await asyncio.sleep(5)

    async def send(self, message: str, expire_second=60):
        """向队列添加消息,带过期时间
        Args:
            message: 要发送的消息
            expire_second: 消息过期时间(秒),默认60秒
        """
        logger.debug(f"添加消息到TCP队列: {message}")
        await self.queue.put((message, time.time(), expire_second))

    async def receive_loop(self):
        """接收服务器消息的循环"""
        while self.running:
            if not self.connected:
                logger.info(f"TCP 服务器 {self.server_ip}:{self.server_port}连接断开,重新连接")
                await self.connect()
            try:
                if self.reader:
                    data = await self.reader.read(1024)
                    if not data:
                        logger.info("服务器断开连接")
                        self.connected = False
                        continue
                    
                    message = data.decode('utf-8').strip()
                    logger.info(f"收到TCP服务器消息: {message}")
                    
                    # 调用所有注册的回调函数处理消息
                    for callback in self.command_callbacks:
                        try:
                            await callback(message)
                        except Exception as e:
                            logger.exception(f"处理TCP消息时出错: {e}")
            except Exception as e:
                logger.exception(f"TCP 接收失败: {e}")
                self.connected = False
                await asyncio.sleep(5)


class TCPClient:
    """TCP客户端管理"""

    def __init__(self, server_ip, server_port):
        self.server_ip = server_ip
        self.server_port = server_port
        self.sock = None
        self.queue = queue.Queue()
        self.lock = threading.Lock()
        self.connected = False
        self.running = True
        self.send_thread = None

    def connect(self):
        """建立TCP连接"""
        while self.running:
            try:
                logger.info(f"连接 TCP 服务器 {self.server_ip}:{self.server_port}")
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                self.sock.connect((self.server_ip, self.server_port))
                self.connected = True
                logger.info("TCP 连接成功")
                return True
            except Exception as e:
                logger.exception(f"TCP 连接失败: {e}, 5 秒后重试")
                self.connected = False
                time.sleep(5)
        return False

    def send_data(self, data: str):
        """线程安全的数据发送"""
        with self.lock:
            if not self.connected:
                if not self.connect():
                    return False

            try:
                self.sock.sendall(data.encode('utf-8'))
                logger.info(f"TCP 发送数据: {data}")
                return True
            except Exception as e:
                logger.exception(f"TCP 发送失败: {e}")
                self.connected = False
                self.sock.close()
                self.sock = None
                return False

    def send_loop(self):
        """发送循环处理队列数据"""
        while self.running:
            try:
                if not self.connected:
                    logger.info(f"TCP 服务器 {self.server_ip}:{self.server_port}连接断开,重新连接")
                    self.connect()

                message, msg_time, expire_second = self.queue.get(timeout=1)
                # 检查消息是否已过期
                current_time = time.time()
                if current_time - msg_time > expire_second:
                    logger.info(f"丢弃过期消息: {message}, 已过期 {current_time - msg_time:.1f} 秒")
                    continue

                if not self.send_data(message):
                    self.queue.put((message, msg_time, expire_second))  # 重新放回队列?
                    time.sleep(5)
            except queue.Empty:
                continue
            except Exception as e:
                logger.exception(f"TCP 工作线程错误: {e}")

    def start(self):
        """启动发送线程"""
        self.send_thread = threading.Thread(target=self.send_loop, daemon=True)
        self.send_thread.start()
        return self

    def send(self, message: str, expire_second=60):
        """向队列添加消息,带过期时间
        
        Args:
            message: 要发送的消息
            expire_second: 消息过期时间(秒),默认60秒
        """
        logger.debug(f"添加消息到TCP队列: {message}")
        self.queue.put((message, time.time(), expire_second))

    def stop(self):
        """停止服务"""
        self.running = False
        if self.sock:
            self.sock.close()