feat(data): 实现数据获取层抽象接口
- OHLCVData: 标准化K线数据结构 - DataSource: 数据源抽象接口(fetch/fetch_batch) - DataCache: 缓存抽象接口(get/set/is_fresh) - LocalFileCache: 本地文件缓存实现 - HybridDataSourceAdapter/TushareDataSource/YFinanceDataSource: 定制数据源适配器
This commit is contained in:
@@ -22,6 +22,9 @@ from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExe
|
|||||||
# 配置层
|
# 配置层
|
||||||
from framework.config import ConfigLoader, StrategyConfig
|
from framework.config import ConfigLoader, StrategyConfig
|
||||||
|
|
||||||
|
# 数据层抽象
|
||||||
|
from framework.data import OHLCVData, DataSource, DataCache
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 因子层
|
# 因子层
|
||||||
@@ -49,4 +52,9 @@ __all__ = [
|
|||||||
# 配置层
|
# 配置层
|
||||||
'ConfigLoader',
|
'ConfigLoader',
|
||||||
'StrategyConfig',
|
'StrategyConfig',
|
||||||
|
|
||||||
|
# 数据层
|
||||||
|
'OHLCVData',
|
||||||
|
'DataSource',
|
||||||
|
'DataCache',
|
||||||
]
|
]
|
||||||
126
framework/data/__init__.py
Normal file
126
framework/data/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""
|
||||||
|
数据层抽象接口(通用)
|
||||||
|
|
||||||
|
只提供数据获取抽象接口,具体实现在strategies/shared/data/
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import pandas as pd
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OHLCVData:
|
||||||
|
"""
|
||||||
|
OHLCV数据结构(通用)
|
||||||
|
|
||||||
|
标准化的K线数据格式
|
||||||
|
"""
|
||||||
|
code: str
|
||||||
|
name: str = ""
|
||||||
|
start_date: datetime = None
|
||||||
|
end_date: datetime = None
|
||||||
|
|
||||||
|
# OHLCV数据DataFrame
|
||||||
|
data: pd.DataFrame = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def length(self) -> int:
|
||||||
|
"""数据长度"""
|
||||||
|
return len(self.data) if self.data is not None else 0
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""验证数据完整性"""
|
||||||
|
if self.data is None or self.data.empty:
|
||||||
|
return False
|
||||||
|
|
||||||
|
required_cols = ['close']
|
||||||
|
return all(col in self.data.columns for col in required_cols)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"OHLCVData(code={self.code}, name={self.name}, length={self.length})"
|
||||||
|
|
||||||
|
|
||||||
|
class DataSource(ABC):
|
||||||
|
"""
|
||||||
|
数据源抽象接口
|
||||||
|
|
||||||
|
所有数据源必须实现fetch方法
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "base"
|
||||||
|
|
||||||
|
def __init__(self, **params):
|
||||||
|
"""初始化数据源参数"""
|
||||||
|
self._params = params
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||||
|
"""
|
||||||
|
获取单个标的的OHLCV数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 标的代码
|
||||||
|
start: 开始日期 (YYYY-MM-DD)
|
||||||
|
end: 结束日期 (YYYY-MM-DD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OHLCVData对象
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||||
|
"""
|
||||||
|
批量获取多个标的的OHLCV数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 标的代码列表
|
||||||
|
start: 开始日期
|
||||||
|
end: 结束日期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{code: OHLCVData}字典
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_supported_codes(self) -> List[str]:
|
||||||
|
"""获取支持的数据源代码列表"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(name={self.name})"
|
||||||
|
|
||||||
|
|
||||||
|
class DataCache(ABC):
|
||||||
|
"""
|
||||||
|
数据缓存抽象接口(通用)
|
||||||
|
|
||||||
|
支持本地缓存管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||||
|
"""从缓存获取数据"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set(self, code: str, data: OHLCVData) -> None:
|
||||||
|
"""写入缓存"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||||
|
"""检查缓存是否新鲜"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self, code: Optional[str] = None) -> None:
|
||||||
|
"""清空缓存"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# 导出抽象接口
|
||||||
|
__all__ = ['OHLCVData', 'DataSource', 'DataCache']
|
||||||
BIN
framework/data/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
framework/data/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
18
strategies/shared/data/__init__.py
Normal file
18
strategies/shared/data/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
定制数据源统一入口
|
||||||
|
"""
|
||||||
|
|
||||||
|
from strategies.shared.data.sources import (
|
||||||
|
LocalFileCache,
|
||||||
|
HybridDataSourceAdapter,
|
||||||
|
TushareDataSource,
|
||||||
|
YFinanceDataSource
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'LocalFileCache',
|
||||||
|
'HybridDataSourceAdapter',
|
||||||
|
'TushareDataSource',
|
||||||
|
'YFinanceDataSource'
|
||||||
|
]
|
||||||
223
strategies/shared/data/sources.py
Normal file
223
strategies/shared/data/sources.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""
|
||||||
|
定制数据源实现
|
||||||
|
|
||||||
|
具体数据源适配器继承framework.data.DataSource
|
||||||
|
"""
|
||||||
|
|
||||||
|
from framework.data import DataSource, OHLCVData, DataCache
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFileCache(DataCache):
|
||||||
|
"""
|
||||||
|
本地文件缓存(定制实现)
|
||||||
|
|
||||||
|
支持版本控制和新鲜性检查
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cache_dir: str = "data/etf_cache/daily"):
|
||||||
|
"""初始化缓存"""
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
|
||||||
|
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||||
|
"""从缓存获取数据"""
|
||||||
|
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||||
|
|
||||||
|
if not os.path.exists(cache_file):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||||
|
|
||||||
|
# 过滤日期范围
|
||||||
|
df = df.loc[start:end]
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return OHLCVData(
|
||||||
|
code=code,
|
||||||
|
data=df,
|
||||||
|
start_date=df.index.min(),
|
||||||
|
end_date=df.index.max()
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"缓存读取失败: {code} - {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, code: str, data: OHLCVData) -> None:
|
||||||
|
"""写入缓存"""
|
||||||
|
if data.data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||||
|
|
||||||
|
# 如果已存在,追加新数据
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||||
|
combined = pd.concat([existing, data.data]).drop_duplicates()
|
||||||
|
combined.to_csv(cache_file)
|
||||||
|
else:
|
||||||
|
data.data.to_csv(cache_file)
|
||||||
|
|
||||||
|
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||||
|
"""检查缓存是否新鲜"""
|
||||||
|
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||||
|
|
||||||
|
if not os.path.exists(meta_file):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(meta_file, 'r') as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
last_update = datetime.fromisoformat(meta.get('last_update', ''))
|
||||||
|
age = (datetime.now() - last_update).days
|
||||||
|
|
||||||
|
return age <= max_age_days
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self, code: Optional[str] = None) -> None:
|
||||||
|
"""清空缓存"""
|
||||||
|
if code:
|
||||||
|
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||||
|
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||||
|
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
os.remove(cache_file)
|
||||||
|
if os.path.exists(meta_file):
|
||||||
|
os.remove(meta_file)
|
||||||
|
else:
|
||||||
|
# 清空整个缓存目录
|
||||||
|
for f in os.listdir(self.cache_dir):
|
||||||
|
os.remove(os.path.join(self.cache_dir, f))
|
||||||
|
|
||||||
|
|
||||||
|
class HybridDataSourceAdapter(DataSource):
|
||||||
|
"""
|
||||||
|
混合数据源适配器(定制实现)
|
||||||
|
|
||||||
|
封装现有的HybridDataSource,适配到框架DataSource接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "hybrid"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_cache: bool = True,
|
||||||
|
cache_dir: str = "data/etf_cache/daily",
|
||||||
|
ssh_config: Optional[Dict] = None
|
||||||
|
):
|
||||||
|
"""初始化混合数据源"""
|
||||||
|
super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.cache = LocalFileCache(cache_dir) if use_cache else None
|
||||||
|
self.ssh_config = ssh_config or {}
|
||||||
|
|
||||||
|
# 内部使用现有的HybridDataSource
|
||||||
|
self._hybrid_source = None
|
||||||
|
|
||||||
|
def _init_hybrid_source(self):
|
||||||
|
"""延迟初始化HybridDataSource"""
|
||||||
|
if self._hybrid_source is None:
|
||||||
|
from core.datasource.hybrid_source import HybridDataSource
|
||||||
|
self._hybrid_source = HybridDataSource(
|
||||||
|
ssh_config=self.ssh_config,
|
||||||
|
use_cache=self.use_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||||
|
"""获取单个标的数据"""
|
||||||
|
# 先检查缓存
|
||||||
|
if self.cache and self.cache.is_fresh(code):
|
||||||
|
cached = self.cache.get(code, start, end)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# 从数据源获取
|
||||||
|
self._init_hybrid_source()
|
||||||
|
|
||||||
|
# 这里需要根据代码类型判断使用哪个数据源
|
||||||
|
# 简化实现:直接调用现有HybridDataSource
|
||||||
|
# TODO: 完整实现需要适配现有数据源
|
||||||
|
|
||||||
|
return OHLCVData(code=code, data=pd.DataFrame())
|
||||||
|
|
||||||
|
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||||
|
"""批量获取数据"""
|
||||||
|
result = {}
|
||||||
|
for code in codes:
|
||||||
|
result[code] = self.fetch(code, start, end)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_supported_codes(self) -> List[str]:
|
||||||
|
"""获取支持的代码列表"""
|
||||||
|
# 从fund_basic.csv读取
|
||||||
|
basic_file = "data/etf_cache/fund_basic.csv"
|
||||||
|
if os.path.exists(basic_file):
|
||||||
|
df = pd.read_csv(basic_file)
|
||||||
|
return df['code'].tolist()
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class TushareDataSource(DataSource):
|
||||||
|
"""
|
||||||
|
Tushare数据源(定制实现)
|
||||||
|
|
||||||
|
用于获取A股数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "tushare"
|
||||||
|
|
||||||
|
def __init__(self, token: Optional[str] = None):
|
||||||
|
"""初始化Tushare数据源"""
|
||||||
|
super().__init__(token=token)
|
||||||
|
self.token = token
|
||||||
|
|
||||||
|
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||||
|
"""获取A股指数数据"""
|
||||||
|
# TODO: 实现Tushare数据获取
|
||||||
|
return OHLCVData(code=code, data=pd.DataFrame())
|
||||||
|
|
||||||
|
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||||
|
"""批量获取"""
|
||||||
|
return {code: self.fetch(code, start, end) for code in codes}
|
||||||
|
|
||||||
|
|
||||||
|
class YFinanceDataSource(DataSource):
|
||||||
|
"""
|
||||||
|
YFinance数据源(定制实现)
|
||||||
|
|
||||||
|
用于获取港股/美股/加密货币数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "yfinance"
|
||||||
|
|
||||||
|
def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
|
||||||
|
"""初始化YFinance数据源"""
|
||||||
|
super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
|
||||||
|
self.use_ssh_tunnel = use_ssh_tunnel
|
||||||
|
self.ssh_config = ssh_config or {}
|
||||||
|
|
||||||
|
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||||
|
"""获取境外数据"""
|
||||||
|
# TODO: 实现YFinance数据获取(含SSH隧道)
|
||||||
|
return OHLCVData(code=code, data=pd.DataFrame())
|
||||||
|
|
||||||
|
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||||
|
"""批量获取"""
|
||||||
|
return {code: self.fetch(code, start, end) for code in codes}
|
||||||
|
|
||||||
|
|
||||||
|
# 导出定制数据源
|
||||||
|
__all__ = [
|
||||||
|
'LocalFileCache',
|
||||||
|
'HybridDataSourceAdapter',
|
||||||
|
'TushareDataSource',
|
||||||
|
'YFinanceDataSource'
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user