feat(data): 实现数据获取层抽象接口

- OHLCVData: 标准化K线数据结构
- DataSource: 数据源抽象接口(fetch/fetch_batch)
- DataCache: 缓存抽象接口(get/set/is_fresh)
- LocalFileCache: 本地文件缓存实现
- HybridDataSourceAdapter/TushareDataSource/YFinanceDataSource: 定制数据源适配器
This commit is contained in:
2026-05-11 23:24:11 +08:00
parent c5a41b71ae
commit 774758c3b0
5 changed files with 375 additions and 0 deletions

View File

@@ -22,6 +22,9 @@ from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExe
# 配置层
from framework.config import ConfigLoader, StrategyConfig
# 数据层抽象
from framework.data import OHLCVData, DataSource, DataCache
__all__ = [
# 因子层
@@ -49,4 +52,9 @@ __all__ = [
# 配置层
'ConfigLoader',
'StrategyConfig',
# 数据层
'OHLCVData',
'DataSource',
'DataCache',
]

126
framework/data/__init__.py Normal file
View File

@@ -0,0 +1,126 @@
"""
数据层抽象接口(通用)
只提供数据获取抽象接口具体实现在strategies/shared/data/
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
import pandas as pd
from dataclasses import dataclass
from datetime import datetime
@dataclass
class OHLCVData:
"""
OHLCV数据结构通用
标准化的K线数据格式
"""
code: str
name: str = ""
start_date: datetime = None
end_date: datetime = None
# OHLCV数据DataFrame
data: pd.DataFrame = None
@property
def length(self) -> int:
"""数据长度"""
return len(self.data) if self.data is not None else 0
def validate(self) -> bool:
"""验证数据完整性"""
if self.data is None or self.data.empty:
return False
required_cols = ['close']
return all(col in self.data.columns for col in required_cols)
def __repr__(self) -> str:
return f"OHLCVData(code={self.code}, name={self.name}, length={self.length})"
class DataSource(ABC):
"""
数据源抽象接口
所有数据源必须实现fetch方法
"""
name: str = "base"
def __init__(self, **params):
"""初始化数据源参数"""
self._params = params
@abstractmethod
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
"""
获取单个标的的OHLCV数据
Args:
code: 标的代码
start: 开始日期 (YYYY-MM-DD)
end: 结束日期 (YYYY-MM-DD)
Returns:
OHLCVData对象
"""
pass
@abstractmethod
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
"""
批量获取多个标的的OHLCV数据
Args:
codes: 标的代码列表
start: 开始日期
end: 结束日期
Returns:
{code: OHLCVData}字典
"""
pass
def get_supported_codes(self) -> List[str]:
"""获取支持的数据源代码列表"""
return []
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
class DataCache(ABC):
"""
数据缓存抽象接口(通用)
支持本地缓存管理
"""
@abstractmethod
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
"""从缓存获取数据"""
pass
@abstractmethod
def set(self, code: str, data: OHLCVData) -> None:
"""写入缓存"""
pass
@abstractmethod
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
"""检查缓存是否新鲜"""
pass
@abstractmethod
def clear(self, code: Optional[str] = None) -> None:
"""清空缓存"""
pass
# 导出抽象接口
__all__ = ['OHLCVData', 'DataSource', 'DataCache']

Binary file not shown.

View File

@@ -0,0 +1,18 @@
"""
定制数据源统一入口
"""
from strategies.shared.data.sources import (
LocalFileCache,
HybridDataSourceAdapter,
TushareDataSource,
YFinanceDataSource
)
__all__ = [
'LocalFileCache',
'HybridDataSourceAdapter',
'TushareDataSource',
'YFinanceDataSource'
]

View File

@@ -0,0 +1,223 @@
"""
定制数据源实现
具体数据源适配器继承framework.data.DataSource
"""
from framework.data import DataSource, OHLCVData, DataCache
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime
import os
import json
class LocalFileCache(DataCache):
"""
本地文件缓存(定制实现)
支持版本控制和新鲜性检查
"""
def __init__(self, cache_dir: str = "data/etf_cache/daily"):
"""初始化缓存"""
self.cache_dir = cache_dir
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
"""从缓存获取数据"""
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
if not os.path.exists(cache_file):
return None
try:
df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
# 过滤日期范围
df = df.loc[start:end]
if df.empty:
return None
return OHLCVData(
code=code,
data=df,
start_date=df.index.min(),
end_date=df.index.max()
)
except Exception as e:
print(f"缓存读取失败: {code} - {e}")
return None
def set(self, code: str, data: OHLCVData) -> None:
"""写入缓存"""
if data.data is None:
return
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
# 如果已存在,追加新数据
if os.path.exists(cache_file):
existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
combined = pd.concat([existing, data.data]).drop_duplicates()
combined.to_csv(cache_file)
else:
data.data.to_csv(cache_file)
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
"""检查缓存是否新鲜"""
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
if not os.path.exists(meta_file):
return False
try:
with open(meta_file, 'r') as f:
meta = json.load(f)
last_update = datetime.fromisoformat(meta.get('last_update', ''))
age = (datetime.now() - last_update).days
return age <= max_age_days
except:
return False
def clear(self, code: Optional[str] = None) -> None:
"""清空缓存"""
if code:
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
if os.path.exists(cache_file):
os.remove(cache_file)
if os.path.exists(meta_file):
os.remove(meta_file)
else:
# 清空整个缓存目录
for f in os.listdir(self.cache_dir):
os.remove(os.path.join(self.cache_dir, f))
class HybridDataSourceAdapter(DataSource):
"""
混合数据源适配器(定制实现)
封装现有的HybridDataSource适配到框架DataSource接口
"""
name = "hybrid"
def __init__(
self,
use_cache: bool = True,
cache_dir: str = "data/etf_cache/daily",
ssh_config: Optional[Dict] = None
):
"""初始化混合数据源"""
super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
self.use_cache = use_cache
self.cache = LocalFileCache(cache_dir) if use_cache else None
self.ssh_config = ssh_config or {}
# 内部使用现有的HybridDataSource
self._hybrid_source = None
def _init_hybrid_source(self):
"""延迟初始化HybridDataSource"""
if self._hybrid_source is None:
from core.datasource.hybrid_source import HybridDataSource
self._hybrid_source = HybridDataSource(
ssh_config=self.ssh_config,
use_cache=self.use_cache
)
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
"""获取单个标的数据"""
# 先检查缓存
if self.cache and self.cache.is_fresh(code):
cached = self.cache.get(code, start, end)
if cached:
return cached
# 从数据源获取
self._init_hybrid_source()
# 这里需要根据代码类型判断使用哪个数据源
# 简化实现直接调用现有HybridDataSource
# TODO: 完整实现需要适配现有数据源
return OHLCVData(code=code, data=pd.DataFrame())
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
"""批量获取数据"""
result = {}
for code in codes:
result[code] = self.fetch(code, start, end)
return result
def get_supported_codes(self) -> List[str]:
"""获取支持的代码列表"""
# 从fund_basic.csv读取
basic_file = "data/etf_cache/fund_basic.csv"
if os.path.exists(basic_file):
df = pd.read_csv(basic_file)
return df['code'].tolist()
return []
class TushareDataSource(DataSource):
"""
Tushare数据源定制实现
用于获取A股数据
"""
name = "tushare"
def __init__(self, token: Optional[str] = None):
"""初始化Tushare数据源"""
super().__init__(token=token)
self.token = token
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
"""获取A股指数数据"""
# TODO: 实现Tushare数据获取
return OHLCVData(code=code, data=pd.DataFrame())
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
"""批量获取"""
return {code: self.fetch(code, start, end) for code in codes}
class YFinanceDataSource(DataSource):
"""
YFinance数据源定制实现
用于获取港股/美股/加密货币数据
"""
name = "yfinance"
def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
"""初始化YFinance数据源"""
super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
self.use_ssh_tunnel = use_ssh_tunnel
self.ssh_config = ssh_config or {}
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
"""获取境外数据"""
# TODO: 实现YFinance数据获取含SSH隧道
return OHLCVData(code=code, data=pd.DataFrame())
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
"""批量获取"""
return {code: self.fetch(code, start, end) for code in codes}
# 导出定制数据源
__all__ = [
'LocalFileCache',
'HybridDataSourceAdapter',
'TushareDataSource',
'YFinanceDataSource'
]