# safe-algo-pro / tcp / tcp_connection.py
import asyncio
from common.global_logger import logger


class TcpConnection:
    """Asyncio TCP client that only handles connecting, reconnecting, sending and receiving.

    Outgoing messages are queued in ``message_queue`` and written one at a time
    by a background task. Every chunk read from the socket is pushed to
    ``response_queue`` and, if registered, handed to ``data_handler``.
    Reconnection is driven by :meth:`connection_monitor`, which the caller is
    expected to start as its own task.
    """

    def __init__(self, ip: str, port: int, timeout: float = 5, reconnect_interval: float = 3):
        """
        Args:
            ip: Remote host to connect to.
            port: Remote TCP port.
            timeout: Seconds allowed for opening the connection, for each socket
                read, and for waiting on the reply to a request.
            reconnect_interval: Seconds to wait between failed connect attempts.
        """
        self.ip = ip
        self.port = port
        self.timeout = timeout
        self.reconnect_interval = reconnect_interval
        self.reader = None
        self.writer = None
        self.is_connected = False
        self.data_handler = None  # upper-layer callback: async function(data: bytes)
        self.send_lock = asyncio.Lock()
        self.read_lock = asyncio.Lock()
        self.message_queue = asyncio.Queue()  # pending (message, have_response) tuples to send
        self.response_queue = asyncio.Queue()  # every chunk read from the socket
        self._read_task = None
        self._message_task = None

    async def connect(self):
        """Block until a connection is open, retrying every ``reconnect_interval`` seconds.

        On success, any previous background tasks are cancelled and a fresh read
        loop and message-sending task are started. Because of that cancellation,
        this must not be awaited from inside those background tasks.
        """
        while not self.is_connected:
            try:
                logger.info(f"正在连接 {self.ip}:{self.port}")
                self.reader, self.writer = await asyncio.wait_for(
                    asyncio.open_connection(self.ip, self.port),
                    timeout=self.timeout
                )

                # Make sure the connection really is usable before advertising it.
                if self.writer is None or self.writer.is_closing():
                    raise ConnectionError("连接未能成功建立")

                self.is_connected = True
                logger.info(f"已连接 {self.ip}:{self.port}")

                # Cancel any leftover tasks, then start new background tasks.
                if self._read_task is not None:
                    self._read_task.cancel()
                if self._message_task is not None:
                    self._message_task.cancel()
                self._read_task = asyncio.create_task(self._read_loop())
                self._message_task = asyncio.create_task(self._process_message_queue())
            except Exception as e:
                logger.error(f"连接 {self.ip}:{self.port} 失败: {e},{self.reconnect_interval}s后重试")
                await asyncio.sleep(self.reconnect_interval)

    async def connection_monitor(self):
        """Run forever and reconnect whenever the connection is down.

        Meant to be started by the caller as a separate task; it polls
        ``is_connected`` once per second.
        """
        while True:
            if not self.is_connected:
                await self.connect()
            await asyncio.sleep(1)  # polling interval; adjust as needed

    async def disconnect(self):
        """Close the socket, reset the connection state and stop the background tasks.

        The writer is closed and awaited *before* the state is reset and the
        tasks are cancelled. This method is also called from inside those tasks.
        Cancelling the calling task only takes effect at its next ``await``, so
        doing it last lets the caller finish this method. The caller then exits
        its loop, because ``is_connected`` is already False.
        """
        if self.writer:
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except Exception as e:
                logger.error(f"关闭连接异常: {e}")
        self.reader = None
        self.writer = None
        self.is_connected = False
        # Cancel the background tasks so they stop touching the closed connection.
        if self._read_task:
            self._read_task.cancel()
            self._read_task = None
        if self._message_task:
            self._message_task.cancel()
            self._message_task = None
        logger.info(f"断开连接 {self.ip}:{self.port}")

    async def _read_loop(self):
        """Background task: read from the socket while connected.

        Each non-empty chunk is pushed to ``response_queue`` and passed to
        ``data_handler``. Two things disconnect: end-of-stream (the peer closed
        the connection) and a read error or timeout. Reconnecting is left to
        ``connection_monitor``.
        """
        while self.is_connected:
            # reader can be None briefly; wait and retry.
            if self.reader is None:
                await asyncio.sleep(0.1)
                continue
            try:
                async with self.read_lock:
                    # NOTE(review): a read that sees no data within `timeout`
                    # raises TimeoutError and is treated as a dead link, so an
                    # idle peer gets disconnected; confirm this is intended.
                    data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout)
            except Exception as e:
                logger.exception(f"读取数据出错 {self.ip}:{self.port}: {e}")
                await self.disconnect()
                return
            if not data:
                # read() returns b"" only at EOF, and keeps returning it without
                # ever suspending; looping on it would starve the event loop.
                logger.warning(f"从 {self.ip}:{self.port} 读取到 EOF,对端已关闭连接,断开连接")
                await self.disconnect()
                return
            logger.info(f"从 {self.ip}:{self.port} 收到数据: {data}")
            await self.response_queue.put(data)
            if self.data_handler:
                try:
                    await self.data_handler(data)
                except Exception as e:
                    # A bug in the upper-layer handler is not a transport
                    # failure, so it must not tear down the connection.
                    logger.exception(f"数据处理回调出错 {self.ip}:{self.port}: {e}")

    async def _process_message_queue(self):
        """Background task: send queued messages one at a time while connected."""
        while self.is_connected:
            message, have_response = await self.message_queue.get()
            try:
                await self._send_message_with_retry(message, have_response)
            finally:
                # Pair every get() with task_done() so message_queue.join() works.
                self.message_queue.task_done()

    def _discard_stale_responses(self):
        """Drop everything currently waiting in ``response_queue``.

        Called right before writing a request that waits for a reply. Anything
        already queued arrived before that request was written, so it cannot be
        the reply to it. Examples: a late reply to an earlier request, the tail
        of a fragmented read, or unsolicited data.
        """
        while True:
            try:
                stale = self.response_queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            logger.warning(f"丢弃过期的响应数据 {self.ip}:{self.port}: {stale}")

    async def _send_message_with_retry(self, message: bytes, have_response: bool):
        """Write one message and, if ``have_response``, wait for the next chunk read.

        Returns that chunk, or None. On failure the message is put back at the
        end of ``message_queue`` so it is resent after the next successful
        connect. A send error or reply timeout also disconnects.
        """
        try:
            async with self.send_lock:
                if not self.is_connected:
                    # Do not call connect() here: connect() cancels _message_task,
                    # which is the task running this code. Re-queue and leave
                    # reconnecting to connection_monitor.
                    await self.message_queue.put((message, have_response))
                    return None
                if have_response:
                    self._discard_stale_responses()
                self.writer.write(message)
                await self.writer.drain()
                logger.info(f"向 {self.ip}:{self.port} 发送消息: {message}")
                if have_response:
                    # Wait for the shared read loop to deliver the reply.
                    return await asyncio.wait_for(self.response_queue.get(), timeout=self.timeout)
                return None
        except asyncio.CancelledError:
            # disconnect() triggered by the read loop cancels this task mid-send;
            # keep the message instead of silently dropping it.
            self.message_queue.put_nowait((message, have_response))
            raise
        except Exception as e:
            logger.exception(f"发送消息失败: {e}")
            # Re-queue so the message is sent again after reconnecting.
            await self.message_queue.put((message, have_response))
            await self.disconnect()

    async def send_message(self, message: bytes, have_response: bool = True):
        """Put a message on the send queue; the background task writes it later."""
        await self.message_queue.put((message, have_response))
        logger.info(f"消息已加入队列: {message}")

    def register_data_handler(self, handler):
        """Register the callback for received data; ``handler`` must be an async function(data: bytes)."""
        self.data_handler = handler

    async def start_periodic_query(self, query_command: bytes, interval: float):
        """Enqueue ``query_command`` every ``interval`` seconds, forever.

        The loop never exits. While disconnected, nothing is enqueued and the
        connection state is re-checked once per second.
        NOTE(review): message_queue is unbounded, so if replies take longer than
        `interval`, queries pile up; consider skipping when one is pending.
        """
        while True:
            if not self.is_connected:
                logger.info("当前未连接,等待重连...")
                await asyncio.sleep(1)
                continue
            try:
                # Queries go through the same queue so ordering is preserved.
                await self.send_message(query_command, have_response=True)
            except Exception as e:
                logger.error(f"定时查询发送失败: {e}")
            await asyncio.sleep(interval)