Newer
Older
safe-algo-pro / tcp / tcp_server.py
# TCP服务器类
import asyncio
import codecs
import inspect
import traceback
from typing import Dict, List, Callable, Optional, Any, Type

from common.global_logger import logger


class TcpServer:
    """TCP server that handles device connections and data reception.

    Received bytes are decoded as UTF-8 and passed to every registered data
    callback. Live connections are tracked in ``clients`` (keyed by
    ``"host:port"`` of the peer) so data can be sent to one client or
    broadcast to all of them.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 9001):
        """Configure the listen address; the server is not started until ``start``."""
        self.host = host
        self.port = port
        # asyncio server object, assigned by start().
        self.server = None
        # client_id ("host:port" of the peer) -> writer of each live connection.
        self.clients: Dict[str, asyncio.StreamWriter] = {}
        # Invoked with each decoded message; may be sync or async callables.
        self.on_data_callbacks: List[Callable] = []

    def register_data_callback(self, callback: Callable) -> None:
        """Register *callback* to be called with every decoded message (str)."""
        self.on_data_callbacks.append(callback)

    async def start(self) -> None:
        """Start listening and serve until the task is cancelled.

        Errors while starting or serving are logged rather than raised.
        """
        try:
            self.server = await asyncio.start_server(
                self._handle_client, self.host, self.port
            )
            addr = self.server.sockets[0].getsockname()
            logger.info(f"TCP服务器启动在 {addr}")

            async with self.server:
                await self.server.serve_forever()
        except Exception as e:
            logger.error(f"启动TCP服务器时出错: {e}")
            logger.error(traceback.format_exc())

    async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Serve one client connection until EOF or error.

        Each read (up to 1024 bytes) is decoded and dispatched via
        ``_process_data``. No message framing is done here: a callback
        receives whatever text arrived in one read.
        """
        addr = writer.get_extra_info("peername")
        client_id = f"{addr[0]}:{addr[1]}"
        self.clients[client_id] = writer
        logger.info(f"客户端连接: {client_id}")

        # TCP may split a multi-byte UTF-8 character across reads; the
        # incremental decoder buffers the partial bytes until the rest arrives.
        decoder = codecs.getincrementaldecoder("utf-8")()

        try:
            while True:
                # Receive data
                data = await reader.read(1024)
                if not data:
                    logger.info(f"连接关闭: {client_id}")
                    break

                try:
                    message = decoder.decode(data)
                except UnicodeDecodeError as e:
                    # Drop any buffered partial bytes so the next chunk starts clean.
                    decoder.reset()
                    logger.error(f"无法解析消息来自 {client_id}: {e}")
                    logger.debug(f"原始数据: {data}")
                    continue

                if not message:
                    # Chunk held only the leading bytes of a multi-byte character.
                    continue

                logger.info(f"收到数据({client_id}): {repr(message)}")
                await self._process_data(message)
        except ConnectionResetError:
            logger.info(f"客户端断开: {client_id}")
        except Exception as e:
            logger.error(f"处理客户端 {client_id} 时出错: {e}")
            logger.error(traceback.format_exc())
        finally:
            self.clients.pop(client_id, None)
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as e:
                # Peer may already have reset the connection; closing is best-effort.
                logger.debug(f"关闭客户端 {client_id} 连接时出错: {e}")

    async def _process_data(self, data: str) -> None:
        """Invoke every registered callback with *data*.

        Callbacks may be plain functions or return an awaitable (``async def``
        functions, objects with an async ``__call__``, partials of those); any
        awaitable result is awaited. A failing callback is logged and does not
        prevent the remaining callbacks from running.
        """
        for callback in self.on_data_callbacks:
            try:
                result = callback(data)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.error(f"执行回调时出错: {e}")
                logger.error(traceback.format_exc())

    async def send_data(self, client_id: str, data: bytes) -> bool:
        """Send *data* to one client.

        Returns True on success, False if the client is not connected or the
        write/drain failed (the error is logged).
        """
        writer = self.clients.get(client_id)
        if writer is None:
            logger.warning(f"客户端 {client_id} 未连接")
            return False

        try:
            writer.write(data)
            await writer.drain()
            return True
        except Exception as e:
            logger.error(f"向客户端 {client_id} 发送数据时出错: {e}")
            logger.error(traceback.format_exc())
            return False

    async def broadcast(self, data: bytes, exclude: Optional[List[str]] = None) -> None:
        """Send *data* to every connected client whose id is not in *exclude*.

        Per-client failures are logged and do not stop the broadcast.
        """
        exclude = exclude or []
        # Iterate a snapshot: clients may connect/disconnect while we await
        # drain(), which would otherwise mutate the dict mid-iteration.
        for client_id, writer in list(self.clients.items()):
            if client_id in exclude:
                continue
            if self.clients.get(client_id) is not writer:
                # Disconnected (or replaced) since the snapshot was taken.
                continue
            try:
                writer.write(data)
                await writer.drain()
            except Exception as e:
                logger.error(f"向客户端 {client_id} 广播时出错: {e}")