diff --git a/framework/__init__.py b/framework/__init__.py index 5ab2962..b2119c0 100644 --- a/framework/__init__.py +++ b/framework/__init__.py @@ -22,6 +22,9 @@ from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExe # 配置层 from framework.config import ConfigLoader, StrategyConfig +# 数据层抽象 +from framework.data import OHLCVData, DataSource, DataCache + __all__ = [ # 因子层 @@ -49,4 +52,9 @@ __all__ = [ # 配置层 'ConfigLoader', 'StrategyConfig', + + # 数据层 + 'OHLCVData', + 'DataSource', + 'DataCache', ] \ No newline at end of file diff --git a/framework/data/__init__.py b/framework/data/__init__.py new file mode 100644 index 0000000..9a507d8 --- /dev/null +++ b/framework/data/__init__.py @@ -0,0 +1,126 @@ +""" +数据层抽象接口(通用) + +只提供数据获取抽象接口,具体实现在strategies/shared/data/ +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any +import pandas as pd +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class OHLCVData: + """ + OHLCV数据结构(通用) + + 标准化的K线数据格式 + """ + code: str + name: str = "" + start_date: datetime = None + end_date: datetime = None + + # OHLCV数据DataFrame + data: pd.DataFrame = None + + @property + def length(self) -> int: + """数据长度""" + return len(self.data) if self.data is not None else 0 + + def validate(self) -> bool: + """验证数据完整性""" + if self.data is None or self.data.empty: + return False + + required_cols = ['close'] + return all(col in self.data.columns for col in required_cols) + + def __repr__(self) -> str: + return f"OHLCVData(code={self.code}, name={self.name}, length={self.length})" + + +class DataSource(ABC): + """ + 数据源抽象接口 + + 所有数据源必须实现fetch方法 + """ + + name: str = "base" + + def __init__(self, **params): + """初始化数据源参数""" + self._params = params + + @abstractmethod + def fetch(self, code: str, start: str, end: str) -> OHLCVData: + """ + 获取单个标的的OHLCV数据 + + Args: + code: 标的代码 + start: 开始日期 (YYYY-MM-DD) + end: 结束日期 (YYYY-MM-DD) + + Returns: + OHLCVData对象 + """ + pass + + @abstractmethod + def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: + """ + 批量获取多个标的的OHLCV数据 + + Args: + codes: 标的代码列表 + start: 开始日期 + end: 结束日期 + + Returns: + {code: OHLCVData}字典 + """ + pass + + def get_supported_codes(self) -> List[str]: + """获取支持的数据源代码列表""" + return [] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name})" + + +class DataCache(ABC): + """ + 数据缓存抽象接口(通用) + + 支持本地缓存管理 + """ + + @abstractmethod + def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]: + """从缓存获取数据""" + pass + + @abstractmethod + def set(self, code: str, data: OHLCVData) -> None: + """写入缓存""" + pass + + @abstractmethod + def is_fresh(self, code: str, max_age_days: int = 1) -> bool: + """检查缓存是否新鲜""" + pass + + @abstractmethod + def clear(self, code: Optional[str] = None) -> None: + """清空缓存""" + pass + + +# 导出抽象接口 +__all__ = ['OHLCVData', 'DataSource', 'DataCache'] \ No newline at end of file diff --git a/framework/data/__pycache__/__init__.cpython-312.pyc b/framework/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..db21a09 Binary files /dev/null and b/framework/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/strategies/shared/data/__init__.py b/strategies/shared/data/__init__.py new file mode 100644 index 0000000..4d41264 --- /dev/null +++ b/strategies/shared/data/__init__.py @@ -0,0 +1,18 @@ +""" +定制数据源统一入口 +""" + +from strategies.shared.data.sources import ( + LocalFileCache, + HybridDataSourceAdapter, + TushareDataSource, + YFinanceDataSource +) + + +__all__ = [ + 'LocalFileCache', + 'HybridDataSourceAdapter', + 'TushareDataSource', + 'YFinanceDataSource' +] \ No newline at end of file diff --git a/strategies/shared/data/sources.py b/strategies/shared/data/sources.py new file mode 100644 index 0000000..127adad --- /dev/null +++ b/strategies/shared/data/sources.py @@ -0,0 +1,223 @@ +""" +定制数据源实现 + +具体数据源适配器继承framework.data.DataSource +""" + +from framework.data import DataSource, OHLCVData, DataCache +import pandas as pd +from typing import Dict, List, Optional +from datetime import datetime +import os +import json + + +class LocalFileCache(DataCache): + """ + 本地文件缓存(定制实现) + + 支持版本控制和新鲜性检查 + """ + + def __init__(self, cache_dir: str = "data/etf_cache/daily"): + """初始化缓存""" + self.cache_dir = cache_dir + + def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]: + """从缓存获取数据""" + cache_file = os.path.join(self.cache_dir, f"{code}.csv") + + if not os.path.exists(cache_file): + return None + + try: + df = pd.read_csv(cache_file, index_col=0, parse_dates=True) + + # 过滤日期范围 + df = df.loc[start:end] + + if df.empty: + return None + + return OHLCVData( + code=code, + data=df, + start_date=df.index.min(), + end_date=df.index.max() + ) + except Exception as e: + print(f"缓存读取失败: {code} - {e}") + return None + + def set(self, code: str, data: OHLCVData) -> None: + """写入缓存""" + if data.data is None: + return + + cache_file = os.path.join(self.cache_dir, f"{code}.csv") + + # 如果已存在,追加新数据 + if os.path.exists(cache_file): + existing = pd.read_csv(cache_file, index_col=0, parse_dates=True) + combined = pd.concat([existing, data.data]).drop_duplicates() + combined.to_csv(cache_file) + else: + data.data.to_csv(cache_file) + + def is_fresh(self, code: str, max_age_days: int = 1) -> bool: + """检查缓存是否新鲜""" + meta_file = os.path.join(self.cache_dir, f"{code}.meta.json") + + if not os.path.exists(meta_file): + return False + + try: + with open(meta_file, 'r') as f: + meta = json.load(f) + + last_update = datetime.fromisoformat(meta.get('last_update', '')) + age = (datetime.now() - last_update).days + + return age <= max_age_days + except: + return False + + def clear(self, code: Optional[str] = None) -> None: + """清空缓存""" + if code: + cache_file = os.path.join(self.cache_dir, f"{code}.csv") + meta_file = os.path.join(self.cache_dir, f"{code}.meta.json") + + if os.path.exists(cache_file): + os.remove(cache_file) + if os.path.exists(meta_file): + os.remove(meta_file) + else: + # 清空整个缓存目录 + for f in os.listdir(self.cache_dir): + os.remove(os.path.join(self.cache_dir, f)) + + +class HybridDataSourceAdapter(DataSource): + """ + 混合数据源适配器(定制实现) + + 封装现有的HybridDataSource,适配到框架DataSource接口 + """ + + name = "hybrid" + + def __init__( + self, + use_cache: bool = True, + cache_dir: str = "data/etf_cache/daily", + ssh_config: Optional[Dict] = None + ): + """初始化混合数据源""" + super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config) + self.use_cache = use_cache + self.cache = LocalFileCache(cache_dir) if use_cache else None + self.ssh_config = ssh_config or {} + + # 内部使用现有的HybridDataSource + self._hybrid_source = None + + def _init_hybrid_source(self): + """延迟初始化HybridDataSource""" + if self._hybrid_source is None: + from core.datasource.hybrid_source import HybridDataSource + self._hybrid_source = HybridDataSource( + ssh_config=self.ssh_config, + use_cache=self.use_cache + ) + + def fetch(self, code: str, start: str, end: str) -> OHLCVData: + """获取单个标的数据""" + # 先检查缓存 + if self.cache and self.cache.is_fresh(code): + cached = self.cache.get(code, start, end) + if cached: + return cached + + # 从数据源获取 + self._init_hybrid_source() + + # 这里需要根据代码类型判断使用哪个数据源 + # 简化实现:直接调用现有HybridDataSource + # TODO: 完整实现需要适配现有数据源 + + return OHLCVData(code=code, data=pd.DataFrame()) + + def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: + """批量获取数据""" + result = {} + for code in codes: + result[code] = self.fetch(code, start, end) + return result + + def get_supported_codes(self) -> List[str]: + """获取支持的代码列表""" + # 从fund_basic.csv读取 + basic_file = "data/etf_cache/fund_basic.csv" + if os.path.exists(basic_file): + df = pd.read_csv(basic_file) + return df['code'].tolist() + return [] + + +class TushareDataSource(DataSource): + """ + Tushare数据源(定制实现) + + 用于获取A股数据 + """ + + name = "tushare" + + def __init__(self, token: Optional[str] = None): + """初始化Tushare数据源""" + super().__init__(token=token) + self.token = token + + def fetch(self, code: str, start: str, end: str) -> OHLCVData: + """获取A股指数数据""" + # TODO: 实现Tushare数据获取 + return OHLCVData(code=code, data=pd.DataFrame()) + + def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: + """批量获取""" + return {code: self.fetch(code, start, end) for code in codes} + + +class YFinanceDataSource(DataSource): + """ + YFinance数据源(定制实现) + + 用于获取港股/美股/加密货币数据 + """ + + name = "yfinance" + + def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None): + """初始化YFinance数据源""" + super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config) + self.use_ssh_tunnel = use_ssh_tunnel + self.ssh_config = ssh_config or {} + + def fetch(self, code: str, start: str, end: str) -> OHLCVData: + """获取境外数据""" + # TODO: 实现YFinance数据获取(含SSH隧道) + return OHLCVData(code=code, data=pd.DataFrame()) + + def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]: + """批量获取""" + return {code: self.fetch(code, start, end) for code in codes} + + +# 导出定制数据源 +__all__ = [ + 'LocalFileCache', + 'HybridDataSourceAdapter', + 'TushareDataSource', + 'YFinanceDataSource' +] \ No newline at end of file