diff --git a/app.py b/app.py index 1508e3d..7f0cf4b 100644 --- a/app.py +++ b/app.py @@ -12,6 +12,7 @@ from services.tts import close_all_tts_services from session import cleanup_inactive_sessions +from utils.http_client import close_http_client from websocket.handler import handle_websocket_connection # Module-level cache for HTML content @@ -54,6 +55,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Shutdown: Cleanup resources and cancel tasks await close_all_tts_services() + await close_http_client() # Close shared HTTP client cleanup_task.cancel() try: await cleanup_task diff --git a/config.py b/config.py index e5f8394..ac2314c 100644 --- a/config.py +++ b/config.py @@ -4,6 +4,8 @@ from dotenv import load_dotenv from loguru import logger +from utils.security import mask_sensitive + # Load environment variables load_dotenv() @@ -144,6 +146,12 @@ def get_service_config(cls, service_type: str) -> Dict[str, Any]: return config + @classmethod + def get_service_config_masked(cls, service_type: str) -> Dict[str, str]: + """Get provider-specific configuration with sensitive data masked (for logging)""" + config = cls.get_service_config(service_type) + return mask_sensitive(config) + # Validate configuration Config.validate() diff --git a/requirements.txt b/requirements.txt index 143e2de..65719c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ fastapi>=0.109.0 uvicorn>=0.25.0 python-dotenv>=1.0.0 azure-cognitiveservices-speech>=1.31.0 -httpx>=0.25.2 +httpx[http2]>=0.25.2 openai>=1.11.0 async-timeout>=4.0.3 websockets>=11.0.3 diff --git a/services/tts/azure_tts.py b/services/tts/azure_tts.py index 68ceca9..3345503 100644 --- a/services/tts/azure_tts.py +++ b/services/tts/azure_tts.py @@ -9,13 +9,13 @@ from config import Config from services.tts.base import BaseTTSService +from utils.http_client import HTTPClientManager class AzureTTSService(BaseTTSService): """Azure TTS服务实现""" # 全局资源 - _http_client: Optional[httpx.AsyncClient] = None # 共享HTTP客户端 active_tasks: Set[asyncio.Task] = set() # 活动任务集合,用于中断 def __init__(self, subscription_key: str, region: str, voice_name: str = Config.AZURE_TTS_VOICE) -> None: @@ -38,14 +38,12 @@ def __init__(self, subscription_key: str, region: str, voice_name: str = Config. @classmethod async def get_http_client(cls) -> httpx.AsyncClient: - """获取或创建HTTP客户端 + """获取共享HTTP客户端(使用连接池) Returns: HTTP客户端实例 """ - if cls._http_client is None or (cls._http_client is not None and cls._http_client.is_closed): - cls._http_client = httpx.AsyncClient() - return cls._http_client + return await HTTPClientManager.get_client() async def synthesize_text(self, text: str, websocket: WebSocket, is_first: bool = False) -> None: """将文本合成为语音并发送到客户端 @@ -249,11 +247,7 @@ async def close_all(cls) -> None: """关闭所有TTS资源""" # 取消所有活动任务 await cls.interrupt_all() - - # 关闭HTTP客户端 - if cls._http_client is not None and not cls._http_client.is_closed: - await cls._http_client.aclose() - cls._http_client = None + # HTTP client is managed by HTTPClientManager, no need to close here async def close(self) -> None: """关闭TTS服务,释放资源""" diff --git a/services/tts/minimax_tts.py b/services/tts/minimax_tts.py index fa96067..a999638 100644 --- a/services/tts/minimax_tts.py +++ b/services/tts/minimax_tts.py @@ -9,13 +9,13 @@ from loguru import logger from services.tts.base import BaseTTSService +from utils.http_client import HTTPClientManager class MiniMaxTTSService(BaseTTSService): """MiniMax TTS服务实现""" # 全局资源 - _http_client: Optional[httpx.AsyncClient] = None # 共享HTTP客户端 active_tasks: Set[asyncio.Task] = set() # 活动任务集合,用于中断 def __init__(self, api_key: str, voice_id: str = "male-qn-qingse") -> None: @@ -47,16 +47,12 @@ def __init__(self, api_key: str, voice_id: str = "male-qn-qingse") -> None: @classmethod async def get_http_client(cls) -> httpx.AsyncClient: - """获取或创建HTTP客户端 + """获取共享HTTP客户端(使用连接池) Returns: HTTP客户端实例 """ - if cls._http_client is None or (cls._http_client is not None and cls._http_client.is_closed): - # 设置超时参数 - timeout = httpx.Timeout(30.0, connect=10.0) - cls._http_client = httpx.AsyncClient(timeout=timeout) - return cls._http_client + return await HTTPClientManager.get_client() async def synthesize_text(self, text: str, websocket: WebSocket, is_first: bool = False) -> None: """将文本合成为语音并发送到客户端 @@ -369,11 +365,7 @@ async def close_all(cls) -> None: """关闭所有MiniMax TTS资源""" # 中断所有活动任务 await cls.interrupt_all() - - # 关闭HTTP客户端 - if cls._http_client is not None and not cls._http_client.is_closed: - await cls._http_client.aclose() - cls._http_client = None + # HTTP client is managed by HTTPClientManager, no need to close here async def close(self) -> None: """关闭当前TTS服务实例""" diff --git a/session.py b/session.py index e61fd6f..7d19ca5 100644 --- a/session.py +++ b/session.py @@ -1,12 +1,16 @@ import asyncio import time import uuid +from threading import RLock from typing import Any, Dict, List, Optional from loguru import logger from config import Config +# Thread-safe lock for session dictionary +_sessions_lock = RLock() + class SessionState: """Manages user session state and pipeline resources""" @@ -98,25 +102,27 @@ def _clear_queues(self) -> None: def get_session(session_id: str) -> SessionState: - """Get or create session state""" - if session_id not in _sessions: - _sessions[session_id] = SessionState(session_id) - - # Update activity timestamp - _sessions[session_id].update_activity() - return _sessions[session_id] + """Get or create session state (thread-safe)""" + with _sessions_lock: + if session_id not in _sessions: + _sessions[session_id] = SessionState(session_id) + # Update activity timestamp + _sessions[session_id].update_activity() + return _sessions[session_id] def remove_session(session_id: str) -> None: - """Remove a session""" - if session_id in _sessions: - del _sessions[session_id] - logger.info(f"Session removed: {session_id}") + """Remove a session (thread-safe)""" + with _sessions_lock: + if session_id in _sessions: + del _sessions[session_id] + logger.info(f"Session removed: {session_id}") def get_all_sessions() -> Dict[str, SessionState]: - """Get all active sessions""" - return _sessions + """Get a copy of all active sessions (thread-safe)""" + with _sessions_lock: + return _sessions.copy() async def cleanup_inactive_sessions() -> None: @@ -125,17 +131,20 @@ async def cleanup_inactive_sessions() -> None: try: await asyncio.sleep(60) # Check every minute - inactive_session_ids = [ - session_id - for session_id, state in _sessions.items() - if state.is_inactive() - ] + # Get inactive sessions with lock + with _sessions_lock: + inactive_session_ids = [ + session_id + for session_id, state in _sessions.items() + if state.is_inactive() + ] for session_id in inactive_session_ids: logger.info(f"Cleaning up inactive session: {session_id}") try: - if _sessions[session_id].tts_processor: - await _sessions[session_id].tts_processor.interrupt() + session = get_session(session_id) + if session.tts_processor: + await session.tts_processor.interrupt() except Exception as e: logger.error(f"Error interrupting TTS processor: {e}") diff --git a/static/js/app.js b/static/js/app.js index 8c81664..fe4fa7e 100644 --- a/static/js/app.js +++ b/static/js/app.js @@ -218,6 +218,28 @@ class EventManager { ui.EventBinder.bindButtonClick('stop-btn', SessionManager.endConversation); ui.EventBinder.bindButtonClick('reset-btn', SessionManager.resetConversation); + // 绑定文本输入事件 + ui.EventBinder.bindTextInput( + // 输入变化处理 - 根据输入内容和连接状态更新发送按钮 + () => { + const hasText = ui.elements.textInput?.value.trim().length > 0; + const isConnected = websocketHandler.getSocket()?.readyState === WebSocket.OPEN; + ui.StateManager.updateSendButtonState(hasText && isConnected); + }, + // 提交处理 - 发送文本消息 + () => { + const text = ui.elements.textInput?.value.trim(); + if (text && websocketHandler.getSocket()?.readyState === WebSocket.OPEN) { + // 发送文本输入命令 + websocketHandler.sendCommand('text_input', { text }); + // 清空输入框 + ui.elements.textInput.value = ''; + // 禁用发送按钮 + ui.StateManager.updateSendButtonState(false); + } + } + ); + // 音频上下文恢复 ui.EventBinder.bindAudioContextResume(() => { const audioContext = audioProcessor.getAudioContext(); @@ -255,7 +277,17 @@ class EventManager { */ function init() { try { - websocketHandler.initializeWebSocket(ui.StateManager.updateStatus, ui.elements.startButton); + // 连接成功回调 - 根据输入框内容更新发送按钮状态 + const onConnected = () => { + const hasText = ui.elements.textInput?.value.trim().length > 0; + ui.StateManager.updateSendButtonState(hasText); + }; + + websocketHandler.initializeWebSocket( + ui.StateManager.updateStatus, + ui.elements.startButton, + onConnected + ); EventManager.initialize(); } catch (error) { console.error('应用初始化错误:', error); diff --git a/static/js/websocket-handler.js b/static/js/websocket-handler.js index 7e6ef93..54bbe24 100644 --- a/static/js/websocket-handler.js +++ b/static/js/websocket-handler.js @@ -82,9 +82,11 @@ const websocketHandler = { * 建立与服务器的WebSocket连接,并设置各种事件处理器 * @param {Function} updateStatus - 状态更新函数,用于更新UI状态 * @param {HTMLButtonElement} startButton - 开始按钮元素,用于控制按钮状态 + * @param {Function} onConnected - 连接成功回调函数(可选) */ - initializeWebSocket(updateStatus, startButton) { + initializeWebSocket(updateStatus, startButton, onConnected = null) { this.statusCallback = updateStatus; + this._onConnectedCallback = onConnected; const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; const wsUrl = `${protocol}//${window.location.host}/ws`; @@ -102,6 +104,10 @@ const websocketHandler = { this._updateStatus('idle', '已连接,准备就绪'); startButton.disabled = false; audioProcessor.initAudioContext(); + // 触发连接成功回调 + if (this._onConnectedCallback) { + this._onConnectedCallback(); + } }; // 消息处理 @@ -116,7 +122,7 @@ const websocketHandler = { // 延迟重连 setTimeout(() => { console.log('尝试重新连接...'); - this.initializeWebSocket(updateStatus, startButton); + this.initializeWebSocket(updateStatus, startButton, onConnected); }, WS_CONFIG.RECONNECT_DELAY); }; @@ -370,13 +376,13 @@ const websocketHandler = { /** * 发送命令到服务器 - * @param {string} command - 命令名称 + * @param {string} command - 命令名称 (type字段) * @param {Object} commandData - 命令附加数据 */ sendCommand(command, commandData = {}) { if (this.socket?.readyState === WebSocket.OPEN) { const message = { - command, + type: command, // 后端期望 'type' 字段 ...commandData }; this.socket.send(JSON.stringify(message)); diff --git a/utils/http_client.py b/utils/http_client.py new file mode 100644 index 0000000..53b4b10 --- /dev/null +++ b/utils/http_client.py @@ -0,0 +1,95 @@ +"""HTTP client manager with connection pooling and thread safety""" + +import asyncio +from typing import Optional + +import httpx +from loguru import logger + + +class HTTPClientManager: + """Manages a shared HTTP client with connection pooling + + Thread-safe singleton pattern ensures only one client instance is created, + avoiding connection leaks and improving performance through connection reuse. + """ + + _client: Optional[httpx.AsyncClient] = None + _lock: asyncio.Lock = asyncio.Lock() + + @classmethod + async def get_client( + cls, + timeout: float = 30.0, + connect_timeout: float = 10.0, + max_keepalive_connections: int = 50, + max_connections: int = 100, + ) -> httpx.AsyncClient: + """Get or create a shared HTTP client + + Args: + timeout: Total request timeout in seconds + connect_timeout: Connection timeout in seconds + max_keepalive_connections: Maximum keep-alive connections + max_connections: Maximum total connections + + Returns: + Shared httpx.AsyncClient instance + """ + # Fast path: client already exists and is open + if cls._client is not None and not cls._client.is_closed: + return cls._client + + # Slow path: need to create client (with lock for thread safety) + async with cls._lock: + # Double-check after acquiring lock + if cls._client is not None and not cls._client.is_closed: + return cls._client + + # Create new client with connection pooling + timeout_config = httpx.Timeout( + timeout, + connect=connect_timeout, + pool=connect_timeout, + ) + limits = httpx.Limits( + max_keepalive_connections=max_keepalive_connections, + max_connections=max_connections, + ) + + cls._client = httpx.AsyncClient( + timeout=timeout_config, + limits=limits, + http2=True, # Enable HTTP/2 for better performance + ) + + logger.debug( + f"Created HTTP client: max_connections={max_connections}, " + f"keepalive={max_keepalive_connections}" + ) + + return cls._client + + @classmethod + async def close(cls) -> None: + """Close the shared HTTP client and release resources""" + async with cls._lock: + if cls._client is not None and not cls._client.is_closed: + await cls._client.aclose() + logger.debug("HTTP client closed") + cls._client = None + + @classmethod + def is_available(cls) -> bool: + """Check if the HTTP client is available""" + return cls._client is not None and not cls._client.is_closed + + +async def get_http_client() -> httpx.AsyncClient: + """Convenience function to get the shared HTTP client""" + return await HTTPClientManager.get_client() + + +async def close_http_client() -> None: + """Convenience function to close the shared HTTP client""" + await HTTPClientManager.close() diff --git a/utils/security.py b/utils/security.py new file mode 100644 index 0000000..00bd533 --- /dev/null +++ b/utils/security.py @@ -0,0 +1,108 @@ +"""Security utilities for handling sensitive data""" + +from typing import Any, Dict, Set + + +class SensitiveDataMasker: + """Utility class for masking sensitive data in logs and outputs""" + + # Keys that indicate sensitive values + SENSITIVE_KEYS: Set[str] = { + "api_key", + "apikey", + "api-key", + "subscription_key", + "key", + "token", + "password", + "pwd", + "secret", + "bearer_token", + "authorization", + "credential", + "credentials", + } + + @classmethod + def is_sensitive_key(cls, key: str) -> bool: + """Check if a key name indicates sensitive data""" + key_lower = key.lower().replace("-", "_") + return any(sensitive in key_lower for sensitive in cls.SENSITIVE_KEYS) + + @classmethod + def mask_value(cls, value: Any, key: str = "") -> str: + """Mask a sensitive value, showing only first and last 2 characters + + Args: + value: The value to potentially mask + key: The key name (used to determine if masking is needed) + + Returns: + Masked string if sensitive, otherwise original string representation + """ + if not isinstance(value, str): + value = str(value) + + if not value: + return value + + # Check if this is a sensitive key + if key and cls.is_sensitive_key(key): + if len(value) <= 4: + return "*" * len(value) + return value[:2] + "*" * (len(value) - 4) + value[-2:] + + return value + + @classmethod + def mask_dict(cls, data: Dict[str, Any]) -> Dict[str, str]: + """Mask all sensitive values in a dictionary + + Args: + data: Dictionary potentially containing sensitive values + + Returns: + New dictionary with sensitive values masked + """ + masked = {} + for key, value in data.items(): + if isinstance(value, dict): + masked[key] = cls.mask_dict(value) + else: + masked[key] = cls.mask_value(value, key) + return masked + + @classmethod + def mask_url(cls, url: str) -> str: + """Mask sensitive parameters in a URL + + Args: + url: URL that may contain sensitive query parameters + + Returns: + URL with sensitive parameters masked + """ + if not url: + return url + + # Simple approach: mask common patterns + import re + + # Mask api_key, key, token parameters + patterns = [ + (r"(api_key=)[^&]+", r"\1****"), + (r"(key=)[^&]+", r"\1****"), + (r"(token=)[^&]+", r"\1****"), + (r"(secret=)[^&]+", r"\1****"), + ] + + masked_url = url + for pattern, replacement in patterns: + masked_url = re.sub(pattern, replacement, masked_url, flags=re.IGNORECASE) + + return masked_url + + +def mask_sensitive(data: Dict[str, Any]) -> Dict[str, str]: + """Convenience function to mask sensitive data in a dictionary""" + return SensitiveDataMasker.mask_dict(data) diff --git a/websocket/handler.py b/websocket/handler.py index 19ad12a..bcfa909 100644 --- a/websocket/handler.py +++ b/websocket/handler.py @@ -82,14 +82,29 @@ async def _setup_asr_service(self, websocket: WebSocket, session_id: str, loop: return asr_service async def _handle_messages(self, websocket: WebSocket, asr_service: BaseASRService, session_id: str) -> None: - """Process incoming WebSocket messages""" + """Process incoming WebSocket messages with timeout protection""" + # Timeout for receiving messages (seconds) + # Audio streams should send data frequently, so 60s is generous + MESSAGE_TIMEOUT = 60 + while True: try: - data = await websocket.receive() + # Add timeout to prevent zombie connections + data = await asyncio.wait_for( + websocket.receive(), + timeout=MESSAGE_TIMEOUT + ) if "bytes" in data: await self._handle_audio_data(data["bytes"], asr_service, session_id) elif "text" in data: await self._handle_text_command(data["text"], websocket, asr_service, session_id) + except asyncio.TimeoutError: + logger.warning(f"WebSocket timeout (no message in {MESSAGE_TIMEOUT}s), closing: {session_id}") + try: + await websocket.close(code=1000, reason="Timeout - no activity") + except Exception: + pass + break except WebSocketDisconnect: break except Exception as e: @@ -134,12 +149,15 @@ async def _handle_text_command(self, text: str, websocket: WebSocket, asr_servic "stop": self._handle_stop_command, "start": lambda ws, asr, sid: asr.start_recognition(), "reset": self._handle_reset_command, - "interrupt": self._handle_interrupt_command + "interrupt": self._handle_interrupt_command, } handler = command_handlers.get(command.type) if handler: await handler(websocket, asr_service, session_id) + elif command.type == "text_input": + # Handle text input separately since it needs the text parameter + await self._handle_text_input_command(websocket, command.text, session_id) except json.JSONDecodeError as e: logger.error(f"Invalid JSON in command: {e}") @@ -208,6 +226,25 @@ async def _handle_interrupt_command(self, websocket: WebSocket, asr_service: Bas else: logger.error(f"Cannot get session {session_id}, unable to process interrupt command") + async def _handle_text_input_command(self, websocket: WebSocket, text: str, session_id: str) -> None: + """Handle text input command - send text directly to LLM (bypassing ASR)""" + if not text.strip(): + logger.warning(f"Empty text input received, session ID: {session_id}") + return + + logger.info(f"Text input received: '{text[:50]}...', session ID: {session_id}") + + # Send transcript to client (echo back user input) - use final_transcript type + await websocket.send_json({ + "type": "final_transcript", + "content": text, + "is_partial": False, + "session_id": session_id + }) + + # Process the text through LLM pipeline + await process_final_transcript(websocket, text, session_id) + async def _cleanup(self, websocket: WebSocket, asr_service: BaseASRService, session_id: str, pipeline: PipelineHandler) -> None: """Clean up resources when connection ends""" # Stop ASR service diff --git a/websocket/models.py b/websocket/models.py index 30cec69..466383b 100644 --- a/websocket/models.py +++ b/websocket/models.py @@ -1,6 +1,6 @@ """WebSocket message models for validation using Pydantic""" -from typing import Literal, Optional +from typing import Literal, Optional, Union from pydantic import BaseModel @@ -8,7 +8,7 @@ class WebSocketCommand(BaseModel): """Base model for WebSocket commands from client""" - type: Literal["stop", "start", "reset", "interrupt"] + type: Literal["stop", "start", "reset", "interrupt", "text_input"] class StopCommand(WebSocketCommand): @@ -35,6 +35,13 @@ class InterruptCommand(WebSocketCommand): type: Literal["interrupt"] = "interrupt" +class TextInputCommand(WebSocketCommand): + """Text input command to send text directly to LLM (bypassing ASR)""" + + type: Literal["text_input"] = "text_input" + text: str + + class WebSocketResponse(BaseModel): """Base model for WebSocket responses to client""" @@ -94,7 +101,7 @@ class LLMStreamResponse(WebSocketResponse): is_final: bool = False -def parse_command(data: dict) -> Optional[WebSocketCommand]: +def parse_command(data: dict) -> Optional[Union[WebSocketCommand, "TextInputCommand"]]: """Parse and validate a WebSocket command Args: @@ -110,6 +117,7 @@ def parse_command(data: dict) -> Optional[WebSocketCommand]: "start": StartCommand, "reset": ResetCommand, "interrupt": InterruptCommand, + "text_input": TextInputCommand, } model_class = command_models.get(cmd_type)