diff --git a/datasource/flask_server.py b/datasource/flask_server.py index eedf9e4..6e85c1b 100644 --- a/datasource/flask_server.py +++ b/datasource/flask_server.py @@ -61,6 +61,9 @@ ssh_config: Optional[Dict] = None CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200')) # 默认2小时 +# 默认数据起点(下载全量数据时使用) +DEFAULT_START_DATE = os.getenv('DEFAULT_START_DATE', '2015-01-01') + class TimedCacheEntry: """带时间戳的缓存条目""" @@ -113,29 +116,97 @@ def get_fetcher() -> UniversalDataFetcher: # ============================================================ @lru_cache(maxsize=CACHE_MAXSIZE) -def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]: +def _fetch_full_data_cached(code: str, today: str) -> Optional[str]: """ - 获取数据的缓存版本 - 返回 JSON 序列化的字符串 + 缓存全量数据(从 DEFAULT_START_DATE 到 today) + + 缓存Key: (code, today_date) + - today: 实际的今天日期,用于每日更新缓存 + - 每天下载一次全量数据,避免重复请求 + + Returns: + JSON 序列化的全量数据 """ f = get_fetcher() try: with f: - df = f.fetch(code, start, end) + # 下载全量数据:从默认起点到今天 + df = f.fetch(code, DEFAULT_START_DATE, today) if df is None or len(df) == 0: return None - result = dataframe_to_json(df) - result['code'] = code - result['asset_type'] = AssetTypeDetector.detect(code).value + # 保存为 DataFrame 格式(方便后续切片) + result = { + 'df_json': dataframe_to_json(df), + 'code': code, + 'asset_type': AssetTypeDetector.detect(code).value, + 'data_start': df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None, + 'data_end': df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None, + } return json.dumps(result) except Exception as e: return json.dumps({"error": str(e)}) +def _slice_data_from_cache(cached_data: Dict, start: str, end: str) -> Dict: + """ + 从缓存的全量数据中切片指定日期范围 + + Args: + cached_data: 缓存的全量数据 + start: 用户请求的开始日期 + end: 用户请求的结束日期 + + Returns: + 切片后的数据(JSON格式) + """ + if 'df_json' not in cached_data or 'data' not in cached_data['df_json']: + return cached_data + + # 从缓存数据中重建 DataFrame + records = cached_data['df_json']['data'] + if not records: + return cached_data + + # 转换为 DataFrame + df = pd.DataFrame(records) + if 'date' in df.columns: + df['date'] = pd.to_datetime(df['date']) + df = df.set_index('date') + + # 切片日期范围 + start_dt = pd.to_datetime(start) + end_dt = pd.to_datetime(end) + + # 确保索引已排序 + df = df.sort_index() + + # 切片(使用 loc 进行日期范围选择) + sliced_df = df.loc[start_dt:end_dt] + + if len(sliced_df) == 0: + return { + 'data': [], + 'count': 0, + 'code': cached_data['code'], + 'asset_type': cached_data['asset_type'], + 'requested_range': {'start': start, 'end': end}, + 'available_range': {'start': cached_data['data_start'], 'end': cached_data['data_end']}, + } + + # 转换为 JSON 格式 + result = dataframe_to_json(sliced_df) + result['code'] = cached_data['code'] + result['asset_type'] = cached_data['asset_type'] + result['requested_range'] = {'start': start, 'end': end} + result['available_range'] = {'start': cached_data['data_start'], 'end': cached_data['data_end']} + + return result + + def fetch_data_with_ttl( code: str, start: str, @@ -145,56 +216,76 @@ def fetch_data_with_ttl( """ 获取数据,支持 TTL 缓存 + 缓存策略: + - Key: (code, today_date) 缓存全量数据 + - 每天下载一次全量数据(从 DEFAULT_START_DATE 到今天) + - 用户请求时从缓存切片 start-end 范围返回 + Args: code: 标的代码 - start: 开始日期 - end: 结束日期 + start: 用户请求的开始日期 + end: 用户请求的结束日期 nocache: 是否跳过缓存 Returns: - (data, is_cached): 数据和是否命中缓存 + (data, is_cached): 切片后的数据和是否命中缓存 """ - cache_key = (code, start, end) + # 获取今天的实际日期(用于缓存Key) + today = datetime.now().strftime('%Y-%m-%d') + full_cache_key = (code, today) - # 跳过缓存 + # 跳过缓存:清理缓存后重新下载 if nocache: - _fetch_data_cached.cache_clear() - result_json = _fetch_data_cached(code, start, end) - return (json.loads(result_json) if result_json else None, False) + _fetch_full_data_cached.cache_clear() + global _ttl_cache + _ttl_cache.clear() + result_json = _fetch_full_data_cached(code, today) + if result_json is None: + return None, False + full_data = json.loads(result_json) + return (_slice_data_from_cache(full_data, start, end), False) - # 检查 TTL 缓存 - global _ttl_cache - if cache_key in _ttl_cache: - entry = _ttl_cache[cache_key] + # 检查 TTL 缓存(全量数据缓存) + if full_cache_key in _ttl_cache: + entry = _ttl_cache[full_cache_key] if not entry.is_expired(): - return entry.data, True + # 从缓存切片 + sliced_data = _slice_data_from_cache(entry.data, start, end) + return sliced_data, True # 过期,删除 - del _ttl_cache[cache_key] + del _ttl_cache[full_cache_key] - # 从 LRU 缓存获取 - result_json = _fetch_data_cached(code, start, end) + # 从 LRU 缓存获取全量数据 + result_json = _fetch_full_data_cached(code, today) if result_json is None: return None, False - result = json.loads(result_json) + full_data = json.loads(result_json) - # 存入 TTL 缓存 - _ttl_cache[cache_key] = TimedCacheEntry(result) + # 检查是否有错误 + if "error" in full_data: + return full_data, False - return result, False + # 存入 TTL 缓存(存全量数据) + _ttl_cache[full_cache_key] = TimedCacheEntry(full_data) + + # 从全量数据切片返回用户请求的范围 + sliced_data = _slice_data_from_cache(full_data, start, end) + + return sliced_data, False def clear_cache(): """清理所有缓存""" global _ttl_cache - _fetch_data_cached.cache_clear() + _fetch_full_data_cached.cache_clear() _ttl_cache.clear() def get_cache_info() -> Dict: """获取缓存统计信息""" - info = _fetch_data_cached.cache_info() + info = _fetch_full_data_cached.cache_info() return { "lru_cache": { "hits": info.hits, @@ -204,6 +295,8 @@ def get_cache_info() -> Dict: }, "ttl_cache_size": len(_ttl_cache), "ttl_seconds": CACHE_TTL_SECONDS, + "default_start_date": DEFAULT_START_DATE, + "cache_strategy": "full_data_by_code_and_today", }