Files
etf/archive/strategies/shared/data/sources.py
aszerW c905230a40 refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
2026-06-03 23:41:46 +08:00

223 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
定制数据源实现
具体数据源适配器继承framework.data.DataSource
"""
from framework.data import DataSource, OHLCVData, DataCache
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime
import os
import json
class LocalFileCache(DataCache):
    """
    Local file cache (custom implementation).

    Layout under ``cache_dir``:

    * ``<code>.csv`` -- cached rows; the first column is the (date) index.
    * ``<code>.meta.json`` -- JSON object whose ``last_update`` key holds an
      ISO-8601 timestamp written by :meth:`set` and read by :meth:`is_fresh`.
    """

    def __init__(self, cache_dir: str = "data/etf_cache/daily"):
        """Create a cache rooted at ``cache_dir``; the directory is created lazily by :meth:`set`."""
        self.cache_dir = cache_dir

    def _csv_path(self, code: str) -> str:
        """Return the path of the CSV data file for ``code``."""
        return os.path.join(self.cache_dir, f"{code}.csv")

    def _meta_path(self, code: str) -> str:
        """Return the path of the JSON metadata file for ``code``."""
        return os.path.join(self.cache_dir, f"{code}.meta.json")

    def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
        """
        Read cached rows for ``code`` within ``start``..``end`` (inclusive label slice).

        Returns ``None`` when there is no cache file, when no rows fall in the
        range, or when the file cannot be read (the error is printed, not raised).
        """
        cache_file = self._csv_path(code)
        if not os.path.exists(cache_file):
            return None
        try:
            df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
            # Label slicing needs a monotonic index; files written by older
            # versions of set() were not guaranteed to be sorted.
            df = df.sort_index().loc[start:end]
            if df.empty:
                return None
            return OHLCVData(
                code=code,
                data=df,
                start_date=df.index.min(),
                end_date=df.index.max(),
            )
        except Exception as e:
            print(f"缓存读取失败: {code} - {e}")
            return None

    def set(self, code: str, data: OHLCVData) -> None:
        """
        Merge ``data.data`` into the cache file for ``code`` and stamp ``last_update``.

        Rows are de-duplicated by index label: when the same label exists both
        in the cache and in ``data``, the incoming row wins. ``None`` or empty
        frames are ignored, so they neither create files nor refresh the stamp.
        """
        frame = data.data
        if frame is None or frame.empty:
            return
        os.makedirs(self.cache_dir, exist_ok=True)
        cache_file = self._csv_path(code)
        if os.path.exists(cache_file):
            existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
            # NOTE(review): assumes data.data is indexed by datetimes like the
            # parsed cache file; otherwise labels for the same day won't match.
            combined = pd.concat([existing, frame])
            # drop_duplicates() compares row values and ignores the index: it
            # drops distinct dates with identical prices and keeps two rows for
            # a revised date. De-duplicate on the index instead.
            frame = combined[~combined.index.duplicated(keep="last")].sort_index()
        frame.to_csv(cache_file)
        self._write_meta(code)

    def _write_meta(self, code: str) -> None:
        """Set ``last_update`` to the current local time, preserving other metadata keys."""
        meta_file = self._meta_path(code)
        meta = {}
        try:
            with open(meta_file, "r", encoding="utf-8") as fh:
                loaded = json.load(fh)
            if isinstance(loaded, dict):
                meta = loaded
        except (OSError, ValueError):
            # Missing or corrupt metadata is simply replaced.
            pass
        # Naive local time, matching the datetime.now() comparison in is_fresh().
        meta["last_update"] = datetime.now().isoformat()
        with open(meta_file, "w", encoding="utf-8") as fh:
            json.dump(meta, fh, ensure_ascii=False, indent=2)

    def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
        """
        Return True if ``code``'s ``last_update`` is at most ``max_age_days`` old.

        Age is counted in whole days (``timedelta.days``). A missing, unreadable
        or malformed metadata file counts as stale.
        """
        meta_file = self._meta_path(code)
        if not os.path.exists(meta_file):
            return False
        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            last_update = datetime.fromisoformat(meta.get("last_update", ""))
            age = (datetime.now() - last_update).days
        except (OSError, ValueError, TypeError, AttributeError):
            # ValueError: bad JSON / bad timestamp; AttributeError: meta not a
            # dict; TypeError: tz-aware timestamp vs naive now().
            return False
        return age <= max_age_days

    def clear(self, code: Optional[str] = None) -> None:
        """
        Remove cached files.

        With ``code`` given, remove only that code's CSV and metadata files.
        With ``code`` None, remove every regular file directly in ``cache_dir``.
        Subdirectories are left in place, and a missing directory is a no-op.
        """
        if code is not None:
            # `is not None` so an empty string cannot wipe the whole cache.
            for path in (self._csv_path(code), self._meta_path(code)):
                if os.path.exists(path):
                    os.remove(path)
            return
        if not os.path.isdir(self.cache_dir):
            return
        for name in os.listdir(self.cache_dir):
            path = os.path.join(self.cache_dir, name)
            if os.path.isfile(path):
                os.remove(path)
class HybridDataSourceAdapter(DataSource):
    """
    Hybrid data source adapter (custom implementation).

    Wraps the existing ``core.datasource.hybrid_source.HybridDataSource`` behind
    the framework ``DataSource`` interface, with an optional ``LocalFileCache``
    in front of it.

    NOTE(review): ``fetch`` is still a stub. On a cache miss it returns an empty
    ``OHLCVData`` and does not delegate to the wrapped source.
    """

    name = "hybrid"

    def __init__(
        self,
        use_cache: bool = True,
        cache_dir: str = "data/etf_cache/daily",
        ssh_config: Optional[Dict] = None,
    ):
        """Initialise the adapter; the wrapped HybridDataSource is built on first use."""
        super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
        self.use_cache = use_cache
        self.cache = LocalFileCache(cache_dir) if use_cache else None
        self.ssh_config = ssh_config or {}
        # Created by _init_hybrid_source(); the core import is deferred until then.
        self._hybrid_source = None

    def _init_hybrid_source(self):
        """Lazily construct the wrapped HybridDataSource (once)."""
        if self._hybrid_source is None:
            from core.datasource.hybrid_source import HybridDataSource
            self._hybrid_source = HybridDataSource(
                ssh_config=self.ssh_config,
                use_cache=self.use_cache,
            )

    def fetch(self, code: str, start: str, end: str) -> OHLCVData:
        """
        Return data for ``code`` between ``start`` and ``end``.

        A fresh cache entry is served first. Otherwise the wrapped source is
        initialised, and (until implemented) an empty ``OHLCVData`` is returned.
        """
        if self.cache is not None and self.cache.is_fresh(code):
            cached = self.cache.get(code, start, end)
            # get() returns Optional[OHLCVData]; test for None explicitly.
            if cached is not None:
                return cached
        self._init_hybrid_source()
        # TODO: choose the underlying source by code type and delegate to
        # self._hybrid_source; this simplified version returns an empty frame.
        return OHLCVData(code=code, data=pd.DataFrame())

    def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
        """Fetch each code via :meth:`fetch` and map it to its result."""
        return {code: self.fetch(code, start, end) for code in codes}

    def get_supported_codes(self) -> List[str]:
        """
        Return the codes listed in ``data/etf_cache/fund_basic.csv``.

        The path is relative to the current working directory. Returns an empty
        list if the file does not exist. Blank codes are skipped.
        """
        basic_file = "data/etf_cache/fund_basic.csv"
        if not os.path.exists(basic_file):
            return []
        # Read codes as strings so numeric codes keep leading zeros
        # (e.g. "000300") and match the List[str] contract.
        df = pd.read_csv(basic_file, dtype={"code": str})
        return df["code"].dropna().tolist()
class TushareDataSource(DataSource):
    """
    Custom Tushare data source, intended for A-share data.

    Placeholder: ``fetch`` is not implemented and always yields an empty frame.
    """

    name = "tushare"

    def __init__(self, token: Optional[str] = None):
        """Forward ``token`` to the base class and keep it on the instance."""
        super().__init__(token=token)
        self.token = token

    def fetch(self, code: str, start: str, end: str) -> OHLCVData:
        """Return A-share index data for ``code`` (currently an empty ``OHLCVData``)."""
        # TODO: implement the actual Tushare download.
        empty_frame = pd.DataFrame()
        return OHLCVData(code=code, data=empty_frame)

    def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
        """Fetch every code in order and collect the results by code."""
        batch: Dict[str, OHLCVData] = {}
        for symbol in codes:
            batch[symbol] = self.fetch(symbol, start, end)
        return batch
class YFinanceDataSource(DataSource):
    """
    Custom YFinance data source, intended for Hong Kong / US equities and crypto.

    Placeholder: ``fetch`` is not implemented and always yields an empty frame.
    """

    name = "yfinance"

    def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
        """Forward the tunnel settings to the base class and keep them on the instance."""
        super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
        self.use_ssh_tunnel = use_ssh_tunnel
        self.ssh_config = ssh_config if ssh_config else {}

    def fetch(self, code: str, start: str, end: str) -> OHLCVData:
        """Return overseas market data for ``code`` (currently an empty ``OHLCVData``)."""
        # TODO: implement YFinance download, including the SSH tunnel path.
        empty_frame = pd.DataFrame()
        return OHLCVData(code=code, data=empty_frame)

    def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
        """Fetch every code in order and collect the results by code."""
        batch: Dict[str, OHLCVData] = {}
        for symbol in codes:
            batch[symbol] = self.fetch(symbol, start, end)
        return batch
# Public API of this module: the local cache and the custom data sources.
__all__ = [
    "LocalFileCache",
    "HybridDataSourceAdapter",
    "TushareDataSource",
    "YFinanceDataSource",
]