From b7f7a756b650d1ce1129d462737d61ff34432c76 Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 23 May 2026 21:20:43 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20SSH=E9=85=8D=E7=BD=AE=E5=AE=8C?= =?UTF-8?q?=E5=85=A8=E5=B0=81=E8=A3=85=E5=88=B0UniversalDataFetcher?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 变更内容: 1. UniversalDataFetcher 新增方法: - get_ssh_config_from_env(): 从环境变量读取 SSH 配置 - from_env(): 工厂方法,自动读取环境变量创建实例 - get_ssh_status(): 返回 SSH 状态信息字典 2. flask_server.py 简化: - 移除 get_ssh_config() 函数(18行) - 移除 ssh_config 全局变量 - get_fetcher() 使用 from_env() - / 和 /health 路由使用 get_ssh_status() 架构改进: - SSH 配置逻辑完全封装在 UniversalDataFetcher - flask_server.py 只依赖 fetcher 接口 - 减少 24 行重复代码 --- datasource/flask_server.py | 43 +++++++------------------- datasource/universal_fetcher.py | 53 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 33 deletions(-) diff --git a/datasource/flask_server.py b/datasource/flask_server.py index 2d7a926..abcd70c 100644 --- a/datasource/flask_server.py +++ b/datasource/flask_server.py @@ -57,7 +57,6 @@ Compress(app) # 启用 gzip 压缩 # 全局数据获取器实例 fetcher: Optional[UniversalDataFetcher] = None -ssh_config: Optional[Dict] = None # 缓存配置 CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) @@ -86,30 +85,12 @@ _ttl_cache: Dict[Tuple, TimedCacheEntry] = {} # 初始化 # ============================================================ -def get_ssh_config() -> Optional[Dict]: - """从环境变量获取 SSH 配置""" - enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' - - if not enabled: - return None - - return { - "enabled": True, - "host": os.getenv('SSH_HOST', ''), - "port": int(os.getenv('SSH_PORT', '22')), - "username": os.getenv('SSH_USERNAME', ''), - "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'), - "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), - } - - def get_fetcher() -> UniversalDataFetcher: - """获取或创建数据获取器实例""" - global fetcher, ssh_config + """获取或创建数据获取器实例(从环境变量读取 SSH 配置)""" + global fetcher if fetcher is None: - ssh_config = get_ssh_config() - fetcher = UniversalDataFetcher(ssh_config=ssh_config) + fetcher = UniversalDataFetcher.from_env() return fetcher @@ -562,12 +543,7 @@ def index(): "crypto": ["BTC", "ETH"], }, "cache_config": get_cache_info(), - "ssh": { - "status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled", - "host": ssh_config.get('host', '') if ssh_config else '', - "required_types": [t.value for t in UniversalDataFetcher.SSH_REQUIRED_TYPES], - "description": "港美股/加密货币数据获取需要 SSH 隧道", - }, + "ssh": get_fetcher().get_ssh_status(), }) @@ -577,7 +553,7 @@ def health(): return jsonify({ "status": "healthy", "timestamp": datetime.now().isoformat(), - "ssh_configured": ssh_config is not None and ssh_config.get('enabled', False), + "ssh": get_fetcher().get_ssh_status(), }) @@ -936,10 +912,11 @@ if __name__ == '__main__': args = parser.parse_args() - # 预加载 SSH 配置 - ssh_config = get_ssh_config() - if ssh_config and ssh_config.get('enabled'): - print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}") + # 预加载 fetcher 并显示 SSH 配置 + f = get_fetcher() + ssh_status = f.get_ssh_status() + if ssh_status['status'] == 'enabled': + print(f"✓ SSH 隧道已配置: {ssh_status['host']}:{ssh_status['port']}") else: print("✗ SSH 隧道未启用(仅支持A股数据)") diff --git a/datasource/universal_fetcher.py b/datasource/universal_fetcher.py index 0b882b8..b3dc0d2 100644 --- a/datasource/universal_fetcher.py +++ b/datasource/universal_fetcher.py @@ -41,6 +41,42 @@ class UniversalDataFetcher: - 对内:各资产类型独立方法,职责单一 """ + @staticmethod + def get_ssh_config_from_env() -> Optional[Dict]: + """ + 从环境变量获取 SSH 配置 + + Returns: + SSH 配置字典或 None + """ + enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' + + if not enabled: + return None + + return { + "enabled": True, + "host": os.getenv('SSH_HOST', ''), + "port": int(os.getenv('SSH_PORT', '22')), + "username": os.getenv('SSH_USERNAME', ''), + "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'), + "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), + } + + @classmethod + def from_env(cls, **kwargs) -> 'UniversalDataFetcher': + """ + 从环境变量创建实例 + + Args: + **kwargs: 其他初始化参数(use_cache, cache_dir 等) + + Returns: + UniversalDataFetcher 实例 + """ + ssh_config = cls.get_ssh_config_from_env() + return cls(ssh_config=ssh_config, **kwargs) + def __init__( self, ssh_config: Optional[Dict] = None, @@ -103,6 +139,23 @@ class UniversalDataFetcher: self._tunnel = None self._tunnel_started = False + def get_ssh_status(self) -> Dict: + """ + 获取 SSH 隧道状态 + + Returns: + SSH 状态信息字典 + """ + enabled = self.ssh_config.get('enabled', False) + return { + "status": "enabled" if enabled else "disabled", + "host": self.ssh_config.get('host', '') if enabled else '', + "port": self.ssh_config.get('port', 22) if enabled else None, + "tunnel_started": self._tunnel_started, + "required_types": [t.value for t in self.SSH_REQUIRED_TYPES], + "description": "港美股/加密货币数据获取需要 SSH 隧道", + } + # ============================================================ # 统一入口(自动路由) # ============================================================