""" 定制数据源实现 具体数据源适配器继承framework.data.DataSource """ from framework.data import DataSource, OHLCVData, DataCache import pandas as pd from typing import Dict, List, Optional from datetime import datetime import os import json class LocalFileCache(DataCache): """ 本地文件缓存(定制实现) 支持版本控制和新鲜性检查 """ def __init__(self, cache_dir: str = "data/etf_cache/daily"): """初始化缓存""" self.cache_dir = cache_dir def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]: """从缓存获取数据""" cache_file = os.path.join(self.cache_dir, f"{code}.csv") if not os.path.exists(cache_file): return None try: df = pd.read_csv(cache_file, index_col=0, parse_dates=True) # 过滤日期范围 df = df.loc[start:end] if df.empty: return None return OHLCVData( code=code, data=df, start_date=df.index.min(), end_date=df.index.max() ) except Exception as e: print(f"缓存读取失败: {code} - {e}") return None def set(self, code: str, data: OHLCVData) -> None: """写入缓存""" if data.data is None: return cache_file = os.path.join(self.cache_dir, f"{code}.csv") # 如果已存在,追加新数据 if os.path.exists(cache_file): existing = pd.read_csv(cache_file, index_col=0, parse_dates=True) combined = pd.concat([existing, data.data]).drop_duplicates() combined.to_csv(cache_file) else: data.data.to_csv(cache_file) def is_fresh(self, code: str, max_age_days: int = 1) -> bool: """检查缓存是否新鲜""" meta_file = os.path.join(self.cache_dir, f"{code}.meta.json") if not os.path.exists(meta_file): return False try: with open(meta_file, 'r') as f: meta = json.load(f) last_update = datetime.fromisoformat(meta.get('last_update', '')) age = (datetime.now() - last_update).days return age <= max_age_days except: return False def clear(self, code: Optional[str] = None) -> None: """清空缓存""" if code: cache_file = os.path.join(self.cache_dir, f"{code}.csv") meta_file = os.path.join(self.cache_dir, f"{code}.meta.json") if os.path.exists(cache_file): os.remove(cache_file) if os.path.exists(meta_file): os.remove(meta_file) else: # 清空整个缓存目录 for f in os.listdir(self.cache_dir): os.remove(os.path.join(self.cache_dir, f)) class HybridDataSourceAdapter(DataSource): """ 混合数据源适配器(定制实现) 封装现有的HybridDataSource,适配到框架DataSource接口 """ name = "hybrid" def __init__( self, use_cache: bool = True, cache_dir: str = "data/etf_cache/daily", ssh_config: Optional[Dict] = None ): """初始化混合数据源""" super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config) self.use_cache = use_cache self.cache = LocalFileCache(cache_dir) if use_cache else None self.ssh_config = ssh_config or {} # 内部使用现有的HybridDataSource self._hybrid_source = None def _init_hybrid_source(self): """延迟初始化HybridDataSource""" if self._hybrid_source is None: from core.datasource.hybrid_source import HybridDataSource self._hybrid_source = HybridDataSource( ssh_config=self.ssh_config, use_cache=self.use_cache ) def fetch(self, code: str, start: str, end: str) -> OHLCVData: """获取单个标的数据""" # 先检查缓存 if self.cache and self.cache.is_fresh(code): cached = self.cache.get(code, start, end) if cached: return cached # 从数据源获取 self._init_hybrid_source() # 这里需要根据代码类型判断使用哪个数据源 # 简化实现:直接调用现有HybridDataSource # TODO: 完整实现需要适配现有数据源 return OHLCVData(code=code, data=pd.DataFrame()) def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: """批量获取数据""" result = {} for code in codes: result[code] = self.fetch(code, start, end) return result def get_supported_codes(self) -> List[str]: """获取支持的代码列表""" # 从fund_basic.csv读取 basic_file = "data/etf_cache/fund_basic.csv" if os.path.exists(basic_file): df = pd.read_csv(basic_file) return df['code'].tolist() return [] class TushareDataSource(DataSource): """ Tushare数据源(定制实现) 用于获取A股数据 """ name = "tushare" def __init__(self, token: Optional[str] = None): """初始化Tushare数据源""" super().__init__(token=token) self.token = token def fetch(self, code: str, start: str, end: str) -> OHLCVData: """获取A股指数数据""" # TODO: 实现Tushare数据获取 return OHLCVData(code=code, data=pd.DataFrame()) def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: """批量获取""" return {code: self.fetch(code, start, end) for code in codes} class YFinanceDataSource(DataSource): """ YFinance数据源(定制实现) 用于获取港股/美股/加密货币数据 """ name = "yfinance" def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None): """初始化YFinance数据源""" super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config) self.use_ssh_tunnel = use_ssh_tunnel self.ssh_config = ssh_config or {} def fetch(self, code: str, start: str, end: str) -> OHLCVData: """获取境外数据""" # TODO: 实现YFinance数据获取(含SSH隧道) return OHLCVData(code=code, data=pd.DataFrame()) def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: """批量获取""" return {code: self.fetch(code, start, end) for code in codes} # 导出定制数据源 __all__ = [ 'LocalFileCache', 'HybridDataSourceAdapter', 'TushareDataSource', 'YFinanceDataSource' ]