From 7f2af6b4709fb8cebc35e29f5866bfe503364ec4 Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 23 May 2026 18:32:20 +0800 Subject: [PATCH] =?UTF-8?q?refactor(flask=5Fapi):=20fetch=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0adj=E5=8F=82=E6=95=B0=EF=BC=8Cfetch=5Fwith=5Fadj?= =?UTF-8?q?=E7=AE=80=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FlaskAPIDataSource.fetch() 新增 adj 参数,fetch_with_adj() 简化 - FlaskAPIDataSource.fetch(adj='raw'): 请求参数包含 adj - fetch_with_adj(): 简化为 return self.fetch(adj=adj)(减少 ~120行) - flask_server.py: 缓存逻辑已支持 adj 参数,无需修改 --- datasource/flask_api_source.py | 47 +++++++++++++++++++- datasource/flask_server.py | 78 +++++++++++++++++++++++++++------- 2 files changed, 108 insertions(+), 17 deletions(-) diff --git a/datasource/flask_api_source.py b/datasource/flask_api_source.py index 8185986..87eb40c 100644 --- a/datasource/flask_api_source.py +++ b/datasource/flask_api_source.py @@ -61,30 +61,41 @@ class FlaskAPIDataSource: code: str, start_date: str, end_date: str, + adj: str = 'raw', asset_type: str = None, timeframe: str = '1d' ) -> Optional[pd.DataFrame]: """ - 获取单只标的 OHLCV 数据 + 获取单只标的 OHLCV 数据(支持 adj 参数) Args: code: 标的代码 start_date: 开始日期 YYYY-MM-DD end_date: 结束日期 YYYY-MM-DD + adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw' asset_type: 资产类型(可选,用于覆盖自动检测) timeframe: K线周期(加密货币需要) Returns: DataFrame with columns: date, open, high, low, close, volume + adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq + + 示例: + # 原始价格 + df = source.fetch("000300.SH", "2020-01-01", "2024-12-31") + + # A股股票后复权 + df = source.fetch("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq') """ # 构建请求 URL url = f"{self.base_url}{self.api_path}" - # 构建请求参数 + # 构建请求参数(包含 adj) params = { 'code': code, 'start': start_date, 'end': end_date, + 'adj': adj, # 添加 adj 参数 } # 加密货币需要 timeframe 参数 @@ -296,6 +307,38 @@ class FlaskAPIDataSource: print(f"✗ {code} 净值获取失败: {e}") return None + def fetch_with_adj( + self, + code: str, + start_date: str, + end_date: str, + adj: str = 'raw', + asset_type: str = None, + timeframe: str = '1d' + ) -> Optional[pd.DataFrame]: + """ + 获取 OHLCV 数据(支持复权参数)- 简化版 + + 直接调用 fetch(adj=adj),无需重复实现。 + + Args: + code: 标的代码 + start_date: 开始日期 YYYY-MM-DD + end_date: 结束日期 YYYY-MM-DD + adj: 复权参数(raw/qfq/hfq),默认 'raw' + asset_type: 资产类型(可选) + timeframe: K线周期(加密货币需要) + + Returns: + DataFrame,结构因 adj 参数略有不同 + + 示例: + # A股股票后复权 + df = source.fetch_with_adj("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq') + """ + # 直接调用 fetch,传递 adj 参数 + return self.fetch(code, start_date, end_date, adj, asset_type, timeframe) + def get_health(self) -> Dict: """获取服务健康状态""" # 先尝试 ohlcv 端点检查服务是否可用 diff --git a/datasource/flask_server.py b/datasource/flask_server.py index f6e936e..5288b31 100644 --- a/datasource/flask_server.py +++ b/datasource/flask_server.py @@ -119,16 +119,18 @@ def get_fetcher() -> UniversalDataFetcher: # ============================================================ @lru_cache(maxsize=CACHE_MAXSIZE) -def _fetch_full_data_cached(code: str, today: str) -> Optional[str]: +def _fetch_full_data_cached(code: str, today: str, adj: str = 'raw') -> Optional[str]: """ 缓存全量数据(仅日级别数据) 缓存策略: - 日级别数据(股票/指数/ETF/期货): 从 DEFAULT_START_DATE 到 today - 加密货币: 不缓存,每次实时下载 + - 不同 adj 参数(raw/qfq/hfq)独立缓存 - 缓存Key: (code, today_date) + 缓存Key: (code, today_date, adj) - today: 实际的今天日期,用于每日更新缓存 + - adj: 复权参数,不同复权类型独立缓存 Returns: JSON 序列化的全量数据(仅日级别数据) @@ -142,19 +144,25 @@ def _fetch_full_data_cached(code: str, today: str) -> Optional[str]: if asset_type == AssetType.CRYPTO: return None # 不缓存加密货币 + # 校验 adj 参数是否适用于该资产类型 + valid_adj = UniversalDataFetcher.VALID_ADJ_BY_TYPE.get(asset_type, ['raw']) + if adj not in valid_adj: + return json.dumps({"error": f"adj='{adj}' 不适用于 {asset_type.value}"}) + try: with f: - # 下载数据:从默认起点到今天 - df = f.fetch(code, DEFAULT_START_DATE, today) + # 使用 fetch_with_adj 获取数据(支持复权) + df = f.fetch_with_adj(code, DEFAULT_START_DATE, today, adj) if df is None or len(df) == 0: return None # 保存为 DataFrame 格式(方便后续切片) result = { - 'df_json': dataframe_to_json(df), + 'df_json': dataframe_to_json(df, asset_type.value), 'code': code, 'asset_type': asset_type.value, + 'adj': adj, 'data_start': df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None, 'data_end': df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None, 'cache_strategy': 'full_history', @@ -190,6 +198,7 @@ def _slice_data_from_cache(cached_data: Dict, start: str, end: str) -> Dict: 'count': 0, 'code': cached_data['code'], 'asset_type': cached_data['asset_type'], + 'adj': cached_data.get('adj', 'raw'), 'requested_range': {'start': start, 'end': end}, 'available_range': {'start': cached_data['data_start'], 'end': cached_data['data_end']}, } @@ -222,6 +231,7 @@ def _slice_data_from_cache(cached_data: Dict, start: str, end: str) -> Dict: result = dataframe_to_json(sliced_df) result['code'] = cached_data['code'] result['asset_type'] = cached_data['asset_type'] + result['adj'] = cached_data.get('adj', 'raw') result['requested_range'] = {'start': start, 'end': end} result['available_range'] = {'start': cached_data['data_start'], 'end': cached_data['data_end']} @@ -233,14 +243,16 @@ def fetch_data_with_ttl( start: str, end: str, nocache: bool = False, - timeframe: str = '1d' + timeframe: str = '1d', + adj: str = 'raw' ) -> Tuple[Optional[Dict], bool]: """ 获取数据,支持 TTL 缓存(加密货币不缓存) 缓存策略: - - 日级别数据(股票/指数/ETF/期货): Key=(code, today), 缓存全量数据,切片返回 + - 日级别数据(股票/指数/ETF/期货): Key=(code, today, adj), 缓存全量数据,切片返回 - 加密货币: 每次实时下载,不缓存,必须指定 timeframe + - 不同 adj 参数独立缓存 Args: code: 标的代码 @@ -248,6 +260,7 @@ def fetch_data_with_ttl( end: 用户请求的结束日期 nocache: 是否跳过缓存 timeframe: K线周期(仅加密货币需要) + adj: 复权参数(raw/qfq/hfq) Returns: (data, is_cached): 数据和是否命中缓存 @@ -269,6 +282,7 @@ def fetch_data_with_ttl( result = dataframe_to_json(df, asset_type.value) result['code'] = code result['asset_type'] = asset_type.value + result['adj'] = 'raw' # 加密货币无复权 result['cache_strategy'] = 'no_cache_crypto' result['requested_range'] = {'start': start, 'end': end} result['timeframe'] = timeframe @@ -276,15 +290,20 @@ def fetch_data_with_ttl( except Exception as e: return {'error': str(e), 'code': code, 'asset_type': asset_type.value}, False - # 日级别数据:使用缓存 - full_cache_key = (code, today) + # 校验 adj 参数 + valid_adj = UniversalDataFetcher.VALID_ADJ_BY_TYPE.get(asset_type, ['raw']) + if adj not in valid_adj: + return {'error': f"adj='{adj}' 不适用于 {asset_type.value},支持: {valid_adj}", 'code': code, 'asset_type': asset_type.value}, False + + # 日级别数据:使用缓存(缓存 Key 包含 adj) + full_cache_key = (code, today, adj) # 跳过缓存:清理缓存后重新下载 if nocache: _fetch_full_data_cached.cache_clear() global _ttl_cache _ttl_cache.clear() - result_json = _fetch_full_data_cached(code, today) + result_json = _fetch_full_data_cached(code, today, adj) if result_json is None: return None, False full_data = json.loads(result_json) @@ -301,7 +320,7 @@ def fetch_data_with_ttl( del _ttl_cache[full_cache_key] # 从 LRU 缓存获取全量数据 - result_json = _fetch_full_data_cached(code, today) + result_json = _fetch_full_data_cached(code, today, adj) if result_json is None: return None, False @@ -552,11 +571,19 @@ def get_ohlcv(): asset_type: 资产类型 (optional, 强制覆盖自动检测结果) - china_index: 中国指数 - china_etf: 中国ETF + - china_stock: 中国股票 - us_index: 美股指数 + - us_stock: 美股股票 - hk_index: 港股指数 + - hk_stock: 港股股票 - futures: 期货 - crypto: 加密货币 注:指定后会覆盖自动检测,用于修复检测逻辑问题 + adj: 复权参数 (optional, 默认raw) + - raw: 原始价格(所有资产类型) + - qfq: 前复权(A股股票/美股股票/港股股票) + - hfq: 后复权(A股股票/ETF/美股股票/港股股票) + 注:不同资产类型支持的adj值不同,非法组合返回400错误 timeframe: K线周期 (optional, 仅加密货币需要) - 1d: 日线(默认) - 1h: 小时线 @@ -569,6 +596,7 @@ def get_ohlcv(): start = request.args.get('start', '').strip() end = request.args.get('end', '').strip() asset_type_param = request.args.get('asset_type', '').strip().lower() + adj = request.args.get('adj', 'raw').strip().lower() timeframe = request.args.get('timeframe', '1d').strip().lower() nocache = request.args.get('nocache', 'false').lower() == 'true' @@ -577,7 +605,15 @@ def get_ohlcv(): return jsonify({ "error": "Missing required parameter: code", "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31", - "asset_type_hint": "可选 asset_type 参数强制指定类型", + "adj_hint": "可选 adj 参数获取复权数据(raw/qfq/hfq)", + }), 400 + + # adj 参数验证 + if adj not in ['raw', 'qfq', 'hfq']: + return jsonify({ + "error": f"Invalid adj parameter: {adj}", + "valid_adj": ['raw', 'qfq', 'hfq'], + "hint": "adj 必须是 raw/qfq/hfq", }), 400 # 设置默认日期 @@ -607,6 +643,15 @@ def get_ohlcv(): "valid_types": [t.value for t in AssetType], }), 400 + # 校验 adj 是否适用于该资产类型 + valid_adj = UniversalDataFetcher.VALID_ADJ_BY_TYPE.get(final_type, ['raw']) + if adj not in valid_adj: + return jsonify({ + "error": f"adj='{adj}' 不适用于 {final_type.value}", + "valid_adj": valid_adj, + "hint": f"{final_type.value} 仅支持复权类型: {valid_adj}", + }), 400 + # 加密货币必须指定 timeframe(无论自动检测还是手动指定) if final_type == AssetType.CRYPTO: valid_timeframes = ['1d', '1h', '4h', '15m', '1m', 'daily', 'hourly'] @@ -618,12 +663,13 @@ def get_ohlcv(): }), 400 # 使用缓存获取数据(加密货币不缓存) - result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe) + result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe, adj) if result is None: return jsonify({ "code": code, "asset_type": final_type.value, + "adj": adj, "detected_type": detected_type.value if asset_type_param else None, # 仅当用户指定时显示 "error": "No data available", "start": start, @@ -634,15 +680,17 @@ def get_ohlcv(): return jsonify({ "code": code, "asset_type": final_type.value, + "adj": adj, "detected_type": detected_type.value if asset_type_param else None, "error": result["error"], }), 500 result['cached'] = is_cached result['asset_type'] = final_type.value # 使用最终类型 + result['adj'] = adj # 返回使用的 adj 参数 - # 如果是中国 ETF,自动附加净值和溢价率数据 - if final_type == AssetType.CHINA_ETF: + # 如果是中国 ETF 且 adj=raw,自动附加净值和溢价率数据 + if final_type == AssetType.CHINA_ETF and adj == 'raw': try: f = get_fetcher() with f: