Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
223 lines
6.9 KiB
Python
223 lines
6.9 KiB
Python
"""
|
||
定制数据源实现
|
||
|
||
具体数据源适配器继承framework.data.DataSource
|
||
"""
|
||
|
||
from framework.data import DataSource, OHLCVData, DataCache
|
||
import pandas as pd
|
||
from typing import Dict, List, Optional
|
||
from datetime import datetime
|
||
import os
|
||
import json
|
||
|
||
|
||
class LocalFileCache(DataCache):
    """
    Local CSV file cache (custom implementation).

    Each instrument is stored as ``<cache_dir>/<code>.csv`` with the date index
    in the first column. An optional sidecar file
    ``<cache_dir>/<code>.meta.json`` holds a ``last_update`` ISO-8601 timestamp
    that drives :meth:`is_fresh`.
    """

    def __init__(self, cache_dir: str = "data/etf_cache/daily"):
        """Initialize the cache.

        Args:
            cache_dir: Directory holding the per-code CSV and meta files.
                It is created on the first :meth:`set` call if missing.
        """
        self.cache_dir = cache_dir

    def _csv_path(self, code: str) -> str:
        """Return the path of the CSV file holding ``code``'s data."""
        return os.path.join(self.cache_dir, f"{code}.csv")

    def _meta_path(self, code: str) -> str:
        """Return the path of the metadata sidecar file for ``code``."""
        return os.path.join(self.cache_dir, f"{code}.meta.json")

    def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
        """Read cached data for ``code`` restricted to ``[start, end]``.

        Args:
            code: Instrument code (used as the file-name stem).
            start: Inclusive start label for slicing the date index
                (e.g. ``"2024-01-01"``).
            end: Inclusive end label.

        Returns:
            An ``OHLCVData`` for the cached rows in range, or ``None`` when the
            file is missing, cannot be read, or has no rows in range.
        """
        cache_file = self._csv_path(code)
        if not os.path.exists(cache_file):
            return None

        try:
            df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
            # Label slicing on a date index requires a monotonic index; files
            # written by earlier versions of ``set`` may be out of order.
            df = df.sort_index().loc[start:end]
            if df.empty:
                return None
            return OHLCVData(
                code=code,
                data=df,
                start_date=df.index.min(),
                end_date=df.index.max(),
            )
        except Exception as e:
            # Best effort: an unreadable cache file is treated as a cache miss.
            print(f"缓存读取失败: {code} - {e}")
            return None

    def set(self, code: str, data: OHLCVData) -> None:
        """Write ``data`` to the cache, merging it with any existing file.

        Rows are keyed by their date index. When a date is already cached, the
        newly supplied row replaces the old one. The merged frame is stored
        sorted by date. Afterwards the meta file's ``last_update`` is set to the
        current local time so :meth:`is_fresh` recognizes the write.

        ``data.data`` that is ``None`` or empty is ignored.
        """
        new_df = data.data
        if new_df is None or new_df.empty:
            return

        os.makedirs(self.cache_dir, exist_ok=True)
        cache_file = self._csv_path(code)

        if os.path.exists(cache_file):
            existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
            combined = pd.concat([existing, new_df])
            # Deduplicate on the date index, not on row values.
            # ``drop_duplicates()`` would drop distinct days that happen to
            # have identical prices, and would keep two rows for a date whose
            # values changed. ``keep="last"`` lets the new data win.
            combined = combined[~combined.index.duplicated(keep="last")]
        else:
            combined = new_df

        combined.sort_index().to_csv(cache_file)
        self._touch_meta(code)

    def _touch_meta(self, code: str) -> None:
        """Set ``last_update`` in ``code``'s meta file to the current local time.

        Other keys already present in a readable meta file are preserved.
        """
        meta_file = self._meta_path(code)
        meta: dict = {}
        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if isinstance(loaded, dict):
                meta = loaded
        except (OSError, ValueError):
            # Missing or malformed meta file: start a fresh one.
            pass
        # Naive local time, matching the ``datetime.now()`` used in is_fresh.
        meta["last_update"] = datetime.now().isoformat()
        with open(meta_file, "w", encoding="utf-8") as f:
            json.dump(meta, f, ensure_ascii=False, indent=2)

    def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
        """Return True if ``code``'s cache was updated recently enough.

        ``last_update`` (ISO-8601) is read from the meta file and compared with
        the current local time. Age is counted in whole days, because
        ``timedelta.days`` truncates. With the default ``max_age_days=1``,
        anything updated less than two days ago counts as fresh.

        A missing, unreadable, or malformed meta file yields False.
        """
        meta_file = self._meta_path(code)
        if not os.path.exists(meta_file):
            return False

        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            last_update = datetime.fromisoformat(meta.get("last_update", ""))
            age = (datetime.now() - last_update).days
        except (OSError, ValueError, TypeError, AttributeError):
            # OSError: unreadable file.
            # ValueError: bad JSON, bad encoding, or a bad/empty timestamp.
            # TypeError: non-string timestamp, or tz-aware vs naive mismatch.
            # AttributeError: the JSON document is not an object.
            return False

        return age <= max_age_days

    def clear(self, code: Optional[str] = None) -> None:
        """Delete cached files.

        Args:
            code: When given (including ``""``), remove only that code's CSV
                and meta files. When ``None``, remove every non-directory
                entry directly inside ``cache_dir``.
        """
        if code is not None:
            for path in (self._csv_path(code), self._meta_path(code)):
                if os.path.exists(path):
                    os.remove(path)
            return

        if not os.path.isdir(self.cache_dir):
            # Nothing has been cached yet.
            return
        for entry in os.listdir(self.cache_dir):
            path = os.path.join(self.cache_dir, entry)
            # os.remove cannot delete directories; leave them alone.
            if not os.path.isdir(path):
                os.remove(path)
|
||
|
||
|
||
class HybridDataSourceAdapter(DataSource):
    """
    Hybrid data source adapter (custom implementation).

    Wraps the existing ``core.datasource.hybrid_source.HybridDataSource`` so it
    can be used through the framework ``DataSource`` interface. An optional
    ``LocalFileCache`` is consulted first.

    NOTE(review): ``fetch`` is still a stub. On a cache miss it returns an
    empty ``OHLCVData`` instead of querying the wrapped source.
    """

    name = "hybrid"

    def __init__(
        self,
        use_cache: bool = True,
        cache_dir: str = "data/etf_cache/daily",
        ssh_config: Optional[Dict] = None,
    ):
        """Initialize the hybrid data source.

        Args:
            use_cache: Whether to consult a ``LocalFileCache`` before fetching;
                also forwarded to the wrapped ``HybridDataSource``.
            cache_dir: Directory for the local cache (unused when
                ``use_cache`` is False).
            ssh_config: Options forwarded to ``HybridDataSource``; a falsy
                value becomes an empty dict.
        """
        super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
        self.use_cache = use_cache
        self.cache = LocalFileCache(cache_dir) if use_cache else None
        self.ssh_config = ssh_config or {}
        # The wrapped HybridDataSource is created lazily (see
        # _init_hybrid_source), so building the adapter does not import it.
        self._hybrid_source = None

    def _init_hybrid_source(self):
        """Create the wrapped ``HybridDataSource`` on first call; no-op afterwards."""
        if self._hybrid_source is None:
            from core.datasource.hybrid_source import HybridDataSource

            self._hybrid_source = HybridDataSource(
                ssh_config=self.ssh_config,
                use_cache=self.use_cache,
            )

    def fetch(self, code: str, start: str, end: str) -> OHLCVData:
        """Fetch data for one instrument over ``[start, end]``.

        Fresh cached data is returned when available. Otherwise the wrapped
        source is initialized but not yet queried (TODO), and an ``OHLCVData``
        with an empty DataFrame is returned.
        """
        if self.cache is not None and self.cache.is_fresh(code):
            cached = self.cache.get(code, start, end)
            if cached is not None:
                return cached

        self._init_hybrid_source()

        # TODO: pick the underlying source based on the code type via the
        # existing HybridDataSource, and store the result in ``self.cache``.
        return OHLCVData(code=code, data=pd.DataFrame())

    def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
        """Fetch several instruments, returning ``{code: OHLCVData}``."""
        return {code: self.fetch(code, start, end) for code in codes}

    def get_supported_codes(self) -> List[str]:
        """Return the instrument codes listed in ``fund_basic.csv``.

        The path is relative to the current working directory. Returns an
        empty list when the file does not exist.
        """
        basic_file = "data/etf_cache/fund_basic.csv"
        if not os.path.exists(basic_file):
            return []
        # Read codes as strings. Dtype inference would turn purely numeric
        # codes into ints and strip leading zeros (e.g. "000001" -> 1),
        # breaking the List[str] contract.
        df = pd.read_csv(basic_file, dtype={"code": str})
        return df["code"].dropna().tolist()
|
||
|
||
|
||
class TushareDataSource(DataSource):
    """
    Tushare data source (custom implementation) for A-share data.

    Placeholder: ``fetch`` is not implemented yet and always returns an
    ``OHLCVData`` wrapping an empty DataFrame.
    """

    name = "tushare"

    def __init__(self, token: Optional[str] = None):
        """Create the source and remember the optional Tushare API token."""
        super().__init__(token=token)
        self.token = token

    def fetch(self, code: str, start: str, end: str) -> OHLCVData:
        """Return A-share index data for ``code`` (not implemented: always empty)."""
        # TODO: implement the Tushare download.
        empty_frame = pd.DataFrame()
        return OHLCVData(code=code, data=empty_frame)

    def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
        """Fetch each code in order and collect the results keyed by code."""
        results: Dict[str, OHLCVData] = {}
        for symbol in codes:
            results[symbol] = self.fetch(symbol, start, end)
        return results
|
||
|
||
|
||
class YFinanceDataSource(DataSource):
    """
    YFinance data source (custom implementation) for Hong Kong / US equities
    and cryptocurrency data.

    Placeholder: ``fetch`` (including the optional SSH-tunnel path) is not
    implemented yet and always returns an ``OHLCVData`` wrapping an empty
    DataFrame.
    """

    name = "yfinance"

    def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
        """Store the tunnel settings; a falsy ``ssh_config`` becomes an empty dict."""
        super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
        self.use_ssh_tunnel = use_ssh_tunnel
        self.ssh_config = ssh_config or {}

    def fetch(self, code: str, start: str, end: str) -> OHLCVData:
        """Return offshore market data for ``code`` (not implemented: always empty)."""
        # TODO: implement the YFinance download, including the SSH tunnel.
        placeholder = pd.DataFrame()
        return OHLCVData(code=code, data=placeholder)

    def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
        """Fetch every code sequentially; returns a ``{code: OHLCVData}`` mapping."""
        collected: Dict[str, OHLCVData] = {}
        for ticker in codes:
            collected[ticker] = self.fetch(ticker, start, end)
        return collected
|
||
|
||
|
||
# Public API of this module: the local cache and the custom data sources.
__all__ = [
    "LocalFileCache",
    "HybridDataSourceAdapter",
    "TushareDataSource",
    "YFinanceDataSource",
]