refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
223
archive/strategies/shared/data/sources.py
Normal file
223
archive/strategies/shared/data/sources.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
定制数据源实现
|
||||
|
||||
具体数据源适配器继承framework.data.DataSource
|
||||
"""
|
||||
|
||||
from framework.data import DataSource, OHLCVData, DataCache
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class LocalFileCache(DataCache):
|
||||
"""
|
||||
本地文件缓存(定制实现)
|
||||
|
||||
支持版本控制和新鲜性检查
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/etf_cache/daily"):
|
||||
"""初始化缓存"""
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||
"""从缓存获取数据"""
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
|
||||
if not os.path.exists(cache_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||
|
||||
# 过滤日期范围
|
||||
df = df.loc[start:end]
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
return OHLCVData(
|
||||
code=code,
|
||||
data=df,
|
||||
start_date=df.index.min(),
|
||||
end_date=df.index.max()
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"缓存读取失败: {code} - {e}")
|
||||
return None
|
||||
|
||||
def set(self, code: str, data: OHLCVData) -> None:
|
||||
"""写入缓存"""
|
||||
if data.data is None:
|
||||
return
|
||||
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
|
||||
# 如果已存在,追加新数据
|
||||
if os.path.exists(cache_file):
|
||||
existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||
combined = pd.concat([existing, data.data]).drop_duplicates()
|
||||
combined.to_csv(cache_file)
|
||||
else:
|
||||
data.data.to_csv(cache_file)
|
||||
|
||||
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||
"""检查缓存是否新鲜"""
|
||||
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||
|
||||
if not os.path.exists(meta_file):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
last_update = datetime.fromisoformat(meta.get('last_update', ''))
|
||||
age = (datetime.now() - last_update).days
|
||||
|
||||
return age <= max_age_days
|
||||
except:
|
||||
return False
|
||||
|
||||
def clear(self, code: Optional[str] = None) -> None:
|
||||
"""清空缓存"""
|
||||
if code:
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
os.remove(cache_file)
|
||||
if os.path.exists(meta_file):
|
||||
os.remove(meta_file)
|
||||
else:
|
||||
# 清空整个缓存目录
|
||||
for f in os.listdir(self.cache_dir):
|
||||
os.remove(os.path.join(self.cache_dir, f))
|
||||
|
||||
|
||||
class HybridDataSourceAdapter(DataSource):
|
||||
"""
|
||||
混合数据源适配器(定制实现)
|
||||
|
||||
封装现有的HybridDataSource,适配到框架DataSource接口
|
||||
"""
|
||||
|
||||
name = "hybrid"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_cache: bool = True,
|
||||
cache_dir: str = "data/etf_cache/daily",
|
||||
ssh_config: Optional[Dict] = None
|
||||
):
|
||||
"""初始化混合数据源"""
|
||||
super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
|
||||
self.use_cache = use_cache
|
||||
self.cache = LocalFileCache(cache_dir) if use_cache else None
|
||||
self.ssh_config = ssh_config or {}
|
||||
|
||||
# 内部使用现有的HybridDataSource
|
||||
self._hybrid_source = None
|
||||
|
||||
def _init_hybrid_source(self):
|
||||
"""延迟初始化HybridDataSource"""
|
||||
if self._hybrid_source is None:
|
||||
from core.datasource.hybrid_source import HybridDataSource
|
||||
self._hybrid_source = HybridDataSource(
|
||||
ssh_config=self.ssh_config,
|
||||
use_cache=self.use_cache
|
||||
)
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取单个标的数据"""
|
||||
# 先检查缓存
|
||||
if self.cache and self.cache.is_fresh(code):
|
||||
cached = self.cache.get(code, start, end)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# 从数据源获取
|
||||
self._init_hybrid_source()
|
||||
|
||||
# 这里需要根据代码类型判断使用哪个数据源
|
||||
# 简化实现:直接调用现有HybridDataSource
|
||||
# TODO: 完整实现需要适配现有数据源
|
||||
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取数据"""
|
||||
result = {}
|
||||
for code in codes:
|
||||
result[code] = self.fetch(code, start, end)
|
||||
return result
|
||||
|
||||
def get_supported_codes(self) -> List[str]:
|
||||
"""获取支持的代码列表"""
|
||||
# 从fund_basic.csv读取
|
||||
basic_file = "data/etf_cache/fund_basic.csv"
|
||||
if os.path.exists(basic_file):
|
||||
df = pd.read_csv(basic_file)
|
||||
return df['code'].tolist()
|
||||
return []
|
||||
|
||||
|
||||
class TushareDataSource(DataSource):
|
||||
"""
|
||||
Tushare数据源(定制实现)
|
||||
|
||||
用于获取A股数据
|
||||
"""
|
||||
|
||||
name = "tushare"
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
"""初始化Tushare数据源"""
|
||||
super().__init__(token=token)
|
||||
self.token = token
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取A股指数数据"""
|
||||
# TODO: 实现Tushare数据获取
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取"""
|
||||
return {code: self.fetch(code, start, end) for code in codes}
|
||||
|
||||
|
||||
class YFinanceDataSource(DataSource):
|
||||
"""
|
||||
YFinance数据源(定制实现)
|
||||
|
||||
用于获取港股/美股/加密货币数据
|
||||
"""
|
||||
|
||||
name = "yfinance"
|
||||
|
||||
def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
|
||||
"""初始化YFinance数据源"""
|
||||
super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
|
||||
self.use_ssh_tunnel = use_ssh_tunnel
|
||||
self.ssh_config = ssh_config or {}
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取境外数据"""
|
||||
# TODO: 实现YFinance数据获取(含SSH隧道)
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取"""
|
||||
return {code: self.fetch(code, start, end) for code in codes}
|
||||
|
||||
|
||||
# 导出定制数据源
|
||||
__all__ = [
|
||||
'LocalFileCache',
|
||||
'HybridDataSourceAdapter',
|
||||
'TushareDataSource',
|
||||
'YFinanceDataSource'
|
||||
]
|
||||
Reference in New Issue
Block a user