feat(data): 实现数据获取层抽象接口
- OHLCVData: 标准化K线数据结构 - DataSource: 数据源抽象接口(fetch/fetch_batch) - DataCache: 缓存抽象接口(get/set/is_fresh) - LocalFileCache: 本地文件缓存实现 - HybridDataSourceAdapter/TushareDataSource/YFinanceDataSource: 定制数据源适配器
This commit is contained in:
@@ -22,6 +22,9 @@ from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExe
|
||||
# 配置层
|
||||
from framework.config import ConfigLoader, StrategyConfig
|
||||
|
||||
# 数据层抽象
|
||||
from framework.data import OHLCVData, DataSource, DataCache
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 因子层
|
||||
@@ -49,4 +52,9 @@ __all__ = [
|
||||
# 配置层
|
||||
'ConfigLoader',
|
||||
'StrategyConfig',
|
||||
|
||||
# 数据层
|
||||
'OHLCVData',
|
||||
'DataSource',
|
||||
'DataCache',
|
||||
]
|
||||
126
framework/data/__init__.py
Normal file
126
framework/data/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
数据层抽象接口(通用)
|
||||
|
||||
只提供数据获取抽象接口,具体实现在strategies/shared/data/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class OHLCVData:
|
||||
"""
|
||||
OHLCV数据结构(通用)
|
||||
|
||||
标准化的K线数据格式
|
||||
"""
|
||||
code: str
|
||||
name: str = ""
|
||||
start_date: datetime = None
|
||||
end_date: datetime = None
|
||||
|
||||
# OHLCV数据DataFrame
|
||||
data: pd.DataFrame = None
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""数据长度"""
|
||||
return len(self.data) if self.data is not None else 0
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证数据完整性"""
|
||||
if self.data is None or self.data.empty:
|
||||
return False
|
||||
|
||||
required_cols = ['close']
|
||||
return all(col in self.data.columns for col in required_cols)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"OHLCVData(code={self.code}, name={self.name}, length={self.length})"
|
||||
|
||||
|
||||
class DataSource(ABC):
|
||||
"""
|
||||
数据源抽象接口
|
||||
|
||||
所有数据源必须实现fetch方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化数据源参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""
|
||||
获取单个标的的OHLCV数据
|
||||
|
||||
Args:
|
||||
code: 标的代码
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
OHLCVData对象
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""
|
||||
批量获取多个标的的OHLCV数据
|
||||
|
||||
Args:
|
||||
codes: 标的代码列表
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
|
||||
Returns:
|
||||
{code: OHLCVData}字典
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_supported_codes(self) -> List[str]:
|
||||
"""获取支持的数据源代码列表"""
|
||||
return []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
|
||||
|
||||
class DataCache(ABC):
|
||||
"""
|
||||
数据缓存抽象接口(通用)
|
||||
|
||||
支持本地缓存管理
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||
"""从缓存获取数据"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, code: str, data: OHLCVData) -> None:
|
||||
"""写入缓存"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||
"""检查缓存是否新鲜"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, code: Optional[str] = None) -> None:
|
||||
"""清空缓存"""
|
||||
pass
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['OHLCVData', 'DataSource', 'DataCache']
|
||||
BIN
framework/data/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
framework/data/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
18
strategies/shared/data/__init__.py
Normal file
18
strategies/shared/data/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
定制数据源统一入口
|
||||
"""
|
||||
|
||||
from strategies.shared.data.sources import (
|
||||
LocalFileCache,
|
||||
HybridDataSourceAdapter,
|
||||
TushareDataSource,
|
||||
YFinanceDataSource
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LocalFileCache',
|
||||
'HybridDataSourceAdapter',
|
||||
'TushareDataSource',
|
||||
'YFinanceDataSource'
|
||||
]
|
||||
223
strategies/shared/data/sources.py
Normal file
223
strategies/shared/data/sources.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
定制数据源实现
|
||||
|
||||
具体数据源适配器继承framework.data.DataSource
|
||||
"""
|
||||
|
||||
from framework.data import DataSource, OHLCVData, DataCache
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class LocalFileCache(DataCache):
|
||||
"""
|
||||
本地文件缓存(定制实现)
|
||||
|
||||
支持版本控制和新鲜性检查
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/etf_cache/daily"):
|
||||
"""初始化缓存"""
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||
"""从缓存获取数据"""
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
|
||||
if not os.path.exists(cache_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||
|
||||
# 过滤日期范围
|
||||
df = df.loc[start:end]
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
return OHLCVData(
|
||||
code=code,
|
||||
data=df,
|
||||
start_date=df.index.min(),
|
||||
end_date=df.index.max()
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"缓存读取失败: {code} - {e}")
|
||||
return None
|
||||
|
||||
def set(self, code: str, data: OHLCVData) -> None:
|
||||
"""写入缓存"""
|
||||
if data.data is None:
|
||||
return
|
||||
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
|
||||
# 如果已存在,追加新数据
|
||||
if os.path.exists(cache_file):
|
||||
existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||
combined = pd.concat([existing, data.data]).drop_duplicates()
|
||||
combined.to_csv(cache_file)
|
||||
else:
|
||||
data.data.to_csv(cache_file)
|
||||
|
||||
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||
"""检查缓存是否新鲜"""
|
||||
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||
|
||||
if not os.path.exists(meta_file):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
last_update = datetime.fromisoformat(meta.get('last_update', ''))
|
||||
age = (datetime.now() - last_update).days
|
||||
|
||||
return age <= max_age_days
|
||||
except:
|
||||
return False
|
||||
|
||||
def clear(self, code: Optional[str] = None) -> None:
|
||||
"""清空缓存"""
|
||||
if code:
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
os.remove(cache_file)
|
||||
if os.path.exists(meta_file):
|
||||
os.remove(meta_file)
|
||||
else:
|
||||
# 清空整个缓存目录
|
||||
for f in os.listdir(self.cache_dir):
|
||||
os.remove(os.path.join(self.cache_dir, f))
|
||||
|
||||
|
||||
class HybridDataSourceAdapter(DataSource):
|
||||
"""
|
||||
混合数据源适配器(定制实现)
|
||||
|
||||
封装现有的HybridDataSource,适配到框架DataSource接口
|
||||
"""
|
||||
|
||||
name = "hybrid"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_cache: bool = True,
|
||||
cache_dir: str = "data/etf_cache/daily",
|
||||
ssh_config: Optional[Dict] = None
|
||||
):
|
||||
"""初始化混合数据源"""
|
||||
super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
|
||||
self.use_cache = use_cache
|
||||
self.cache = LocalFileCache(cache_dir) if use_cache else None
|
||||
self.ssh_config = ssh_config or {}
|
||||
|
||||
# 内部使用现有的HybridDataSource
|
||||
self._hybrid_source = None
|
||||
|
||||
def _init_hybrid_source(self):
|
||||
"""延迟初始化HybridDataSource"""
|
||||
if self._hybrid_source is None:
|
||||
from core.datasource.hybrid_source import HybridDataSource
|
||||
self._hybrid_source = HybridDataSource(
|
||||
ssh_config=self.ssh_config,
|
||||
use_cache=self.use_cache
|
||||
)
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取单个标的数据"""
|
||||
# 先检查缓存
|
||||
if self.cache and self.cache.is_fresh(code):
|
||||
cached = self.cache.get(code, start, end)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# 从数据源获取
|
||||
self._init_hybrid_source()
|
||||
|
||||
# 这里需要根据代码类型判断使用哪个数据源
|
||||
# 简化实现:直接调用现有HybridDataSource
|
||||
# TODO: 完整实现需要适配现有数据源
|
||||
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取数据"""
|
||||
result = {}
|
||||
for code in codes:
|
||||
result[code] = self.fetch(code, start, end)
|
||||
return result
|
||||
|
||||
def get_supported_codes(self) -> List[str]:
|
||||
"""获取支持的代码列表"""
|
||||
# 从fund_basic.csv读取
|
||||
basic_file = "data/etf_cache/fund_basic.csv"
|
||||
if os.path.exists(basic_file):
|
||||
df = pd.read_csv(basic_file)
|
||||
return df['code'].tolist()
|
||||
return []
|
||||
|
||||
|
||||
class TushareDataSource(DataSource):
|
||||
"""
|
||||
Tushare数据源(定制实现)
|
||||
|
||||
用于获取A股数据
|
||||
"""
|
||||
|
||||
name = "tushare"
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
"""初始化Tushare数据源"""
|
||||
super().__init__(token=token)
|
||||
self.token = token
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取A股指数数据"""
|
||||
# TODO: 实现Tushare数据获取
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取"""
|
||||
return {code: self.fetch(code, start, end) for code in codes}
|
||||
|
||||
|
||||
class YFinanceDataSource(DataSource):
|
||||
"""
|
||||
YFinance数据源(定制实现)
|
||||
|
||||
用于获取港股/美股/加密货币数据
|
||||
"""
|
||||
|
||||
name = "yfinance"
|
||||
|
||||
def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
|
||||
"""初始化YFinance数据源"""
|
||||
super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
|
||||
self.use_ssh_tunnel = use_ssh_tunnel
|
||||
self.ssh_config = ssh_config or {}
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取境外数据"""
|
||||
# TODO: 实现YFinance数据获取(含SSH隧道)
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取"""
|
||||
return {code: self.fetch(code, start, end) for code in codes}
|
||||
|
||||
|
||||
# 导出定制数据源
|
||||
__all__ = [
|
||||
'LocalFileCache',
|
||||
'HybridDataSourceAdapter',
|
||||
'TushareDataSource',
|
||||
'YFinanceDataSource'
|
||||
]
|
||||
Reference in New Issue
Block a user