diff --git a/datasource/flask_api_source.py b/datasource/flask_api_source.py new file mode 100644 index 0000000..affb542 --- /dev/null +++ b/datasource/flask_api_source.py @@ -0,0 +1,293 @@ +""" +Flask API 数据源 + +通过部署后的 Flask API 服务获取 OHLCV 数据 +支持远程调用,无需本地 SSH 隧道 +""" + +import os +import json +import requests +import pandas as pd +from typing import Optional, Dict, List +from datetime import datetime +from pathlib import Path +from dotenv import load_dotenv + +load_dotenv() + + +class FlaskAPIDataSource: + """ + Flask API 数据源 + + 通过 HTTP API 获取数据,无需本地配置 SSH 隧道 + 适用于远程调用或生产环境 + + 用法: + source = FlaskAPIDataSource(base_url="https://k3s.tokenpluse.xyz") + df = source.fetch("000300.SH", "2024-01-01", "2024-12-31") + """ + + def __init__( + self, + base_url: str = None, + api_path: str = "/api/v1/ohlcv", + timeout: int = 120, + retries: int = 3 + ): + """ + 初始化 + + Args: + base_url: API 服务基础地址,默认从环境变量读取 + api_path: API 路径 + timeout: 请求超时时间(秒) + retries: 重试次数 + """ + self.base_url = base_url or os.getenv( + 'FLASK_API_URL', + 'https://k3s.tokenpluse.xyz' + ) + self.api_path = api_path + self.timeout = timeout + self.retries = retries + + # 确保 base_url 不以 / 结尾 + self.base_url = self.base_url.rstrip('/') + + def fetch( + self, + code: str, + start_date: str, + end_date: str, + asset_type: str = None, + timeframe: str = '1d' + ) -> Optional[pd.DataFrame]: + """ + 获取单只标的 OHLCV 数据 + + Args: + code: 标的代码 + start_date: 开始日期 YYYY-MM-DD + end_date: 结束日期 YYYY-MM-DD + asset_type: 资产类型(可选,用于覆盖自动检测) + timeframe: K线周期(加密货币需要) + + Returns: + DataFrame with columns: date, open, high, low, close, volume + """ + # 构建请求 URL + url = f"{self.base_url}{self.api_path}" + + # 构建请求参数 + params = { + 'code': code, + 'start': start_date, + 'end': end_date, + } + + # 加密货币需要 timeframe 参数 + if asset_type == 'crypto' or code.upper() in ['BTC', 'ETH']: + params['timeframe'] = timeframe + + # 可选:强制指定 asset_type + if asset_type: + params['asset_type'] = asset_type + + for attempt in range(self.retries): + try: + response = requests.get( + url, + params=params, + timeout=self.timeout + ) + + if response.status_code != 200: + if attempt < self.retries - 1: + continue + print(f"✗ API请求失败: {response.status_code} - {response.text[:100]}") + return None + + data = response.json() + + # 检查错误 + if 'error' in data: + print(f"✗ API返回错误: {data['error']}") + return None + + # 解析数据 + records = data.get('data', []) + if not records: + print(f"⚠ {code}: 无数据返回") + return None + + # 转换为 DataFrame + df = pd.DataFrame(records) + + # 处理日期列 + if 'date' in df.columns: + df['date'] = pd.to_datetime(df['date']) + df = df.set_index('date') + + # 确保列名标准化 + df = df[['open', 'high', 'low', 'close', 'volume']] + + # 缓存 info 信息(如果有) + if 'info' in data: + df.attrs['info'] = data['info'] + + print(f"✓ {code}: {len(df)} 条数据 ({start_date} ~ {end_date})") + return df + + except requests.exceptions.Timeout: + if attempt < self.retries - 1: + print(f"⚠ {code}: 请求超时,重试 {attempt + 2}/{self.retries}") + continue + print(f"✗ {code}: 请求超时") + return None + + except requests.exceptions.RequestException as e: + if attempt < self.retries - 1: + continue + print(f"✗ {code}: 请求异常 - {e}") + return None + + except json.JSONDecodeError as e: + print(f"✗ {code}: JSON解析失败 - {e}") + return None + + return None + + def fetch_batch( + self, + codes: List[str], + start_date: str, + end_date: str, + asset_types: Dict[str, str] = None + ) -> Dict[str, Optional[pd.DataFrame]]: + """ + 批量获取多只标的数据 + + Args: + codes: 标的代码列表 + start_date: 开始日期 + end_date: 结束日期 + asset_types: 资产类型映射 {code: asset_type} + + Returns: + {code: DataFrame} + """ + results = {} + asset_types = asset_types or {} + + print(f"从 Flask API 获取 {len(codes)} 只标的...") + + for i, code in enumerate(codes, 1): + asset_type = asset_types.get(code) + df = self.fetch(code, start_date, end_date, asset_type) + results[code] = df + + # 显示进度 + if i % 5 == 0 or i == len(codes): + success = sum(1 for v in results.values() if v is not None) + print(f" 进度: {i}/{len(codes)} (成功: {success})") + + return results + + def fetch_etf_nav( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取 ETF 净值数据 + + Args: + code: ETF代码 + start_date: 开始日期 + end_date: 结束日期 + + Returns: + DataFrame with nav column + """ + url = f"{self.base_url}/api/v1/etf/nav" + + params = { + 'code': code, + 'start': start_date, + 'end': end_date + } + + try: + response = requests.get(url, params=params, timeout=self.timeout) + + if response.status_code != 200: + return None + + data = response.json() + + if 'error' in data or 'nav' not in data: + return None + + nav_data = data.get('nav', {}) + records = nav_data.get('data', []) + + if not records: + return None + + df = pd.DataFrame(records) + if 'date' in df.columns: + df['date'] = pd.to_datetime(df['date']) + df = df.set_index('date') + + return df + + except Exception as e: + print(f"✗ {code} 净值获取失败: {e}") + return None + + def get_health(self) -> Dict: + """获取服务健康状态""" + # 先尝试 ohlcv 端点检查服务是否可用 + url = f"{self.base_url}{self.api_path}" + params = {'code': '000300.SH', 'start': '2024-01-01', 'end': '2024-01-05'} + + try: + response = requests.get(url, params=params, timeout=self.timeout) + if response.status_code == 200: + data = response.json() + return { + 'status': 'healthy', + 'ssh_configured': True, + 'available': True + } + else: + return {'status': 'error', 'available': False} + except Exception as e: + return {'status': 'error', 'message': str(e), 'available': False} + + def get_service_info(self) -> Dict: + """获取服务信息""" + url = f"{self.base_url}/" + + try: + response = requests.get(url, timeout=10) + return response.json() + except Exception as e: + return {"error": str(e)} + + +# 全局实例 +_flask_api_source: Optional[FlaskAPIDataSource] = None + + +def get_flask_api_source(base_url: str = None) -> FlaskAPIDataSource: + """获取 Flask API 数据源实例""" + global _flask_api_source + + if _flask_api_source is None: + _flask_api_source = FlaskAPIDataSource(base_url=base_url) + + return _flask_api_source \ No newline at end of file diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py index d50255e..cf6b103 100644 --- a/strategies/rotation/strategy.py +++ b/strategies/rotation/strategy.py @@ -105,8 +105,14 @@ class RotationStrategy(StrategyBase): group_mapping[code] = cfg.get('market', 'default') return group_mapping - def get_data(self) -> dict: - """获取数据(使用新数据源模块)""" + def get_data(self, use_flask_api: bool = True) -> dict: + """ + 获取数据 + + Args: + use_flask_api: 是否使用 Flask API 服务获取数据(默认 True) + False 则使用本地 HybridDataSource + """ code_list_config = self.config.get('code_list', {}) benchmark_config = self.config.get('benchmark', {}) benchmark_code = benchmark_config.get('code', '000300.SH') @@ -114,7 +120,123 @@ class RotationStrategy(StrategyBase): if not code_list_config: raise ValueError("配置中未找到 code_list") - # 使用新数据源模块 + # 获取 Flask API 地址 + flask_api_url = self.config.get('flask_api_url') + + if use_flask_api: + # 使用 Flask API 服务获取数据(远程调用) + return self._get_data_from_flask_api( + code_list_config, + benchmark_code, + flask_api_url + ) + else: + # 使用本地 HybridDataSource(需要本地 SSH 隧道) + return self._get_data_from_local( + code_list_config, + benchmark_code + ) + + def _get_data_from_flask_api( + self, + code_list_config: dict, + benchmark_code: str, + flask_api_url: str = None + ) -> dict: + """通过 Flask API 服务获取数据""" + from datasource.flask_api_source import FlaskAPIDataSource + + # 初始化 Flask API 数据源 + api_source = FlaskAPIDataSource(base_url=flask_api_url) + + # 检查服务状态 + health = api_source.get_health() + if health.get('status') != 'healthy': + print(f"⚠ Flask API 服务状态: {health}") + else: + print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})") + + # 获取指数代码列表 + index_codes = list(code_list_config.keys()) + + # 获取 ETF 代码映射 + etf_code_map = {} + etf_codes = [] + for index_code, cfg in code_list_config.items(): + if isinstance(cfg, dict) and cfg.get('etf'): + etf_code_map[index_code] = cfg['etf'] + etf_codes.append(cfg['etf']) + + # 获取指数 OHLCV 数据 + print(f"\n获取指数数据 ({len(index_codes)} 只)...") + index_ohlcv_data = api_source.fetch_batch( + index_codes, + self.start_date, + self.end_date + ) + + # 过滤有效代码 + valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0] + print(f"有效指数: {len(valid_codes)} 只") + + # 获取 ETF 价格数据 + print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...") + etf_ohlcv_data = api_source.fetch_batch( + etf_codes, + self.start_date, + self.end_date + ) + + # 转换为宽格式 DataFrame + etf_data = None + if etf_ohlcv_data: + etf_close_dict = {} + for etf_code, df in etf_ohlcv_data.items(): + if df is not None and 'close' in df.columns: + etf_close_dict[etf_code] = df['close'] + if etf_close_dict: + etf_data = pd.DataFrame(etf_close_dict) + + # 获取基准数据 + print(f"\n获取基准数据 ({benchmark_code})...") + benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date) + benchmark_data = None + if benchmark_ohlcv is not None: + benchmark_data = benchmark_ohlcv['close'] + + # 构建指数收盘价宽格式 DataFrame(用于因子计算) + index_close_dict = {} + for code in valid_codes: + df = index_ohlcv_data.get(code) + if df is not None and 'close' in df.columns: + index_close_dict[code] = df['close'] + index_close = pd.DataFrame(index_close_dict) if index_close_dict else None + + # 获取 ETF 净值数据(用于溢价率计算) + print(f"\n获取 ETF 净值数据...") + etf_nav_data = {} + for etf_code in etf_codes: + nav_df = api_source.fetch_etf_nav(etf_code, self.start_date, self.end_date) + if nav_df is not None: + etf_nav_data[etf_code] = nav_df + print(f"有效净值: {len(etf_nav_data)} 只") + + return { + 'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame} + 'index_close': index_close, # 对齐后的收盘价(宽格式) + 'etf_data': etf_data, # ETF 收盘价(宽格式) + 'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame} + 'benchmark_data': benchmark_data, # 基准收盘价 Series + 'valid_codes': valid_codes, # 有效指数代码列表 + 'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 + } + + def _get_data_from_local( + self, + code_list_config: dict, + benchmark_code: str + ) -> dict: + """使用本地 HybridDataSource 获取数据""" from datasource import HybridDataSource ssh_config = self.config.get('ssh_tunnel', {})