feat(strategy): rotation策略支持Flask API数据获取
- 新增 flask_api_source.py: Flask API远程数据源模块 - 修改 strategy.py: get_data() 支持通过Flask API获取数据 使用方式: strategy.get_data(use_flask_api=True) # 通过部署服务获取 strategy.get_data(use_flask_api=False) # 本地HybridDataSource 配置项: flask_api_url: 可在config.yaml中指定API地址
This commit is contained in:
@@ -105,8 +105,14 @@ class RotationStrategy(StrategyBase):
|
||||
group_mapping[code] = cfg.get('market', 'default')
|
||||
return group_mapping
|
||||
|
||||
def get_data(self) -> dict:
|
||||
"""获取数据(使用新数据源模块)"""
|
||||
def get_data(self, use_flask_api: bool = True) -> dict:
|
||||
"""
|
||||
获取数据
|
||||
|
||||
Args:
|
||||
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True)
|
||||
False 则使用本地 HybridDataSource
|
||||
"""
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
benchmark_config = self.config.get('benchmark', {})
|
||||
benchmark_code = benchmark_config.get('code', '000300.SH')
|
||||
@@ -114,7 +120,123 @@ class RotationStrategy(StrategyBase):
|
||||
if not code_list_config:
|
||||
raise ValueError("配置中未找到 code_list")
|
||||
|
||||
# 使用新数据源模块
|
||||
# 获取 Flask API 地址
|
||||
flask_api_url = self.config.get('flask_api_url')
|
||||
|
||||
if use_flask_api:
|
||||
# 使用 Flask API 服务获取数据(远程调用)
|
||||
return self._get_data_from_flask_api(
|
||||
code_list_config,
|
||||
benchmark_code,
|
||||
flask_api_url
|
||||
)
|
||||
else:
|
||||
# 使用本地 HybridDataSource(需要本地 SSH 隧道)
|
||||
return self._get_data_from_local(
|
||||
code_list_config,
|
||||
benchmark_code
|
||||
)
|
||||
|
||||
def _get_data_from_flask_api(
|
||||
self,
|
||||
code_list_config: dict,
|
||||
benchmark_code: str,
|
||||
flask_api_url: str = None
|
||||
) -> dict:
|
||||
"""通过 Flask API 服务获取数据"""
|
||||
from datasource.flask_api_source import FlaskAPIDataSource
|
||||
|
||||
# 初始化 Flask API 数据源
|
||||
api_source = FlaskAPIDataSource(base_url=flask_api_url)
|
||||
|
||||
# 检查服务状态
|
||||
health = api_source.get_health()
|
||||
if health.get('status') != 'healthy':
|
||||
print(f"⚠ Flask API 服务状态: {health}")
|
||||
else:
|
||||
print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")
|
||||
|
||||
# 获取指数代码列表
|
||||
index_codes = list(code_list_config.keys())
|
||||
|
||||
# 获取 ETF 代码映射
|
||||
etf_code_map = {}
|
||||
etf_codes = []
|
||||
for index_code, cfg in code_list_config.items():
|
||||
if isinstance(cfg, dict) and cfg.get('etf'):
|
||||
etf_code_map[index_code] = cfg['etf']
|
||||
etf_codes.append(cfg['etf'])
|
||||
|
||||
# 获取指数 OHLCV 数据
|
||||
print(f"\n获取指数数据 ({len(index_codes)} 只)...")
|
||||
index_ohlcv_data = api_source.fetch_batch(
|
||||
index_codes,
|
||||
self.start_date,
|
||||
self.end_date
|
||||
)
|
||||
|
||||
# 过滤有效代码
|
||||
valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
|
||||
print(f"有效指数: {len(valid_codes)} 只")
|
||||
|
||||
# 获取 ETF 价格数据
|
||||
print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
|
||||
etf_ohlcv_data = api_source.fetch_batch(
|
||||
etf_codes,
|
||||
self.start_date,
|
||||
self.end_date
|
||||
)
|
||||
|
||||
# 转换为宽格式 DataFrame
|
||||
etf_data = None
|
||||
if etf_ohlcv_data:
|
||||
etf_close_dict = {}
|
||||
for etf_code, df in etf_ohlcv_data.items():
|
||||
if df is not None and 'close' in df.columns:
|
||||
etf_close_dict[etf_code] = df['close']
|
||||
if etf_close_dict:
|
||||
etf_data = pd.DataFrame(etf_close_dict)
|
||||
|
||||
# 获取基准数据
|
||||
print(f"\n获取基准数据 ({benchmark_code})...")
|
||||
benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
|
||||
benchmark_data = None
|
||||
if benchmark_ohlcv is not None:
|
||||
benchmark_data = benchmark_ohlcv['close']
|
||||
|
||||
# 构建指数收盘价宽格式 DataFrame(用于因子计算)
|
||||
index_close_dict = {}
|
||||
for code in valid_codes:
|
||||
df = index_ohlcv_data.get(code)
|
||||
if df is not None and 'close' in df.columns:
|
||||
index_close_dict[code] = df['close']
|
||||
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
|
||||
|
||||
# 获取 ETF 净值数据(用于溢价率计算)
|
||||
print(f"\n获取 ETF 净值数据...")
|
||||
etf_nav_data = {}
|
||||
for etf_code in etf_codes:
|
||||
nav_df = api_source.fetch_etf_nav(etf_code, self.start_date, self.end_date)
|
||||
if nav_df is not None:
|
||||
etf_nav_data[etf_code] = nav_df
|
||||
print(f"有效净值: {len(etf_nav_data)} 只")
|
||||
|
||||
return {
|
||||
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
|
||||
'index_close': index_close, # 对齐后的收盘价(宽格式)
|
||||
'etf_data': etf_data, # ETF 收盘价(宽格式)
|
||||
'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame}
|
||||
'benchmark_data': benchmark_data, # 基准收盘价 Series
|
||||
'valid_codes': valid_codes, # 有效指数代码列表
|
||||
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
|
||||
}
|
||||
|
||||
def _get_data_from_local(
|
||||
self,
|
||||
code_list_config: dict,
|
||||
benchmark_code: str
|
||||
) -> dict:
|
||||
"""使用本地 HybridDataSource 获取数据"""
|
||||
from datasource import HybridDataSource
|
||||
|
||||
ssh_config = self.config.get('ssh_tunnel', {})
|
||||
|
||||
Reference in New Issue
Block a user