diff --git a/datasource/__init__.py b/datasource/__init__.py index 9b9da9d..d24d233 100644 --- a/datasource/__init__.py +++ b/datasource/__init__.py @@ -4,22 +4,27 @@ 核心数据获取能力: - A股数据:Tushare(指数、ETF、期货) - 境外数据:YFinance(港股、美股)通过SSH隧道 +- 加密货币:CCXT(OKX)通过 socks2http 架构设计: -- 分层架构:对外统一接口,对内各资产类型独立实现 +- 分层架构:基础层统一接口,扩展层资产类型特有方法 - Flask API:LRU + TTL 双缓存机制 用法: - from datasource import UniversalDataFetcher, AssetType + from datasource import UniversalDataFetcher + # 基础层:统一 OHLCV 接口 fetcher = UniversalDataFetcher() df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31") + + # 扩展层:资产类型特有方法 + df_adj = fetcher.fetch_etf_adj("513100.SH", ...) # ETF 后复权 + df_adj = fetcher.fetch_us_adj("AAPL", ...) # 美股复权 """ from .ssh_tunnel import SSHTunnelManager from .tushare_source import TushareSource from .yfinance_source import YFinanceSource -from .hybrid_source import HybridDataSource from .asset_type_detector import AssetTypeDetector, AssetType from .universal_fetcher import UniversalDataFetcher @@ -27,7 +32,6 @@ __all__ = [ 'SSHTunnelManager', 'TushareSource', 'YFinanceSource', - 'HybridDataSource', 'AssetTypeDetector', 'AssetType', 'UniversalDataFetcher', diff --git a/datasource/hybrid_source.py b/datasource/hybrid_source.py deleted file mode 100644 index 2fb4021..0000000 --- a/datasource/hybrid_source.py +++ /dev/null @@ -1,301 +0,0 @@ -""" -混合数据源 - -整合 Tushare(A股) + YFinance(境外)数据获取 -""" - -import os -import time -from typing import Optional, Tuple, Dict, List -from datetime import datetime -from pathlib import Path -import pandas as pd - -from .ssh_tunnel import SSHTunnelManager -from .tushare_source import TushareSource -from .yfinance_source import YFinanceSource - - -class HybridDataSource: - """ - 混合数据源 - - - A股指数/ETF/期货: Tushare - - 港股/美股/商品: YFinance(通过SSH隧道) - - 使用方式: - from datasource import HybridDataSource - - source = HybridDataSource.from_yaml('strategies/rotation/config.yaml') - result = source.fetch_all() - """ - - def __init__( - self, - ssh_config: Optional[dict] = None, - use_cache: bool = True, - cache_dir: str = "data/etf_cache/daily" - ): - """ - 初始化混合数据源 - - Args: - ssh_config: SSH隧道配置 - use_cache: 是否使用缓存 - cache_dir: 缓存目录 - """ - self.ssh_config = ssh_config or {} - self.use_cache = use_cache - self.cache_dir = cache_dir - - # 数据源实例 - self._tushare = TushareSource() - self._yfinance = YFinanceSource() - - # SSH隧道(延迟初始化) - self._tunnel: Optional[SSHTunnelManager] = None - - @classmethod - def from_yaml(cls, config_path: str) -> 'HybridDataSource': - """从YAML配置创建实例""" - import yaml - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - return cls( - ssh_config=config.get('ssh_tunnel', {}), - use_cache=config.get('use_cache', True) - ) - - def _start_tunnel(self) -> bool: - """启动SSH隧道""" - if self._tunnel is None and self.ssh_config.get('enabled'): - self._tunnel = SSHTunnelManager(self.ssh_config) - return self._tunnel.start() - return True - - def _stop_tunnel(self): - """停止SSH隧道""" - if self._tunnel: - self._tunnel.stop() - self._tunnel = None - - def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: - """ - 获取单个标的数据 - - Args: - code: 标的代码 - start_date: 开始日期 - end_date: 结束日期 - - Returns: - DataFrame with OHLCV data - """ - # 判断数据源 - if self._tushare.is_china_index(code) or self._tushare.is_futures(code): - return self._tushare.fetch(code, start_date, end_date) - else: - # YFinance需要SSH隧道 - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def fetch_all( - self, - code_config: dict, - benchmark_code: str = "000300.SH", - start_date: str = "2019-01-01", - end_date: str = None - ) -> Tuple[ - Optional[pd.DataFrame], # index_data: 指数收盘价(宽格式) - Optional[pd.DataFrame], # etf_data: ETF价格(宽格式) - Optional[pd.DataFrame], # etf_nav_data: ETF净值 - Optional[pd.DataFrame], # benchmark_data: 基准数据 - List[str], # valid_codes: 有效代码列表 - Dict[str, pd.DataFrame], # index_ohlcv_data: 原始OHLCV数据 - Dict[str, str] # etf_code_map: {指数代码: ETF代码} 映射 - ]: - """ - 批量获取数据 - - Args: - code_config: 标的配置 {代码: {name, etf, market}} - benchmark_code: 基准代码 - start_date: 开始日期 - end_date: 结束日期 - - Returns: - (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map) - """ - if end_date is None: - end_date = datetime.now().strftime('%Y-%m-%d') - - # 启动SSH隧道 - self._start_tunnel() - - index_codes = list(code_config.keys()) - etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')} - - print(f"开始下载 {len(index_codes)} 只标的的数据...") - print(f" 指数代码: {len(index_codes)} 只") - print(f" ETF映射: {len(etf_codes)} 只") - - # 分类统计 - china_codes = [c for c in index_codes if self._tushare.is_china_index(c)] - futures_codes = [c for c in index_codes if self._tushare.is_futures(c)] - yf_codes = [c for c in index_codes if not self._tushare.is_china_index(c) and not self._tushare.is_futures(c)] - - print(f" 中国A股指数: {len(china_codes)} 只") - print(f" 期货合约: {len(futures_codes)} 只") - print(f" 港股/美股: {len(yf_codes)} 只") - - # 下载指数数据 - print("\n [1/2] 下载指数数据...") - index_data_list = [] - index_ohlcv_data = {} - valid_codes = [] - - for code in index_codes: - name = code_config[code].get('name', code) - source = "Tushare" if self._tushare.is_china_index(code) or self._tushare.is_futures(code) else "YFinance" - - print(f" 下载 {code} ({name}) - {source}...", end=" ") - - data = self.fetch_single(code, start_date, end_date) - - if data is not None and len(data) > 0: - # 标准化 - data = data.copy() - data['source'] = source - data['code'] = code - data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize() - - index_ohlcv_data[code] = data.copy() - index_data_list.append(data[['code', 'close', 'source']]) - valid_codes.append(code) - print(f"✓ {len(data)} 条") - else: - print("✗ 无数据") - - # 下载ETF数据 - etf_data_list = [] - etf_nav_data_list = [] - - if etf_codes: - print("\n [2/2] 下载ETF数据...") - - for idx_code, etf_code in etf_codes.items(): - name = code_config[idx_code].get('name', idx_code) - - print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ") - - # ETF价格 - etf_data = self._tushare.fetch_etf(etf_code, start_date, end_date) - - # ETF净值 - etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date) - - if etf_data is not None and len(etf_data) > 0: - etf_data.index = pd.to_datetime(etf_data.index, utc=True).tz_localize(None).normalize() - etf_data_list.append(etf_data[['code', 'close']]) - - price_count = len(etf_data) - nav_count = len(etf_nav) if etf_nav is not None else 0 - - print(f"✓ 价格{price_count}条 净值{nav_count}条") - else: - print("✗ 无数据") - - if etf_nav is not None and len(etf_nav) > 0: - etf_nav.index = pd.to_datetime(etf_nav.index, utc=True).tz_localize(None).normalize() - etf_nav_data_list.append(etf_nav[['code', 'nav']]) - - # 整合数据 - index_data = None - if index_data_list: - index_data = pd.concat(index_data_list) - if 'code' in index_data.columns and 'close' in index_data.columns: - index_data = index_data.reset_index() - if 'index' in index_data.columns: - index_data = index_data.rename(columns={'index': 'date'}) - index_data['date'] = pd.to_datetime(index_data['date']).dt.normalize() - index_data = index_data.pivot_table(index='date', columns='code', values='close') - - etf_data = None - if etf_data_list: - etf_data = pd.concat(etf_data_list) - if 'code' in etf_data.columns and 'close' in etf_data.columns: - etf_data = etf_data.reset_index() - if 'index' in etf_data.columns: - etf_data = etf_data.rename(columns={'index': 'date'}) - etf_data['date'] = pd.to_datetime(etf_data['date']).dt.normalize() - etf_data = etf_data.pivot_table(index='date', columns='code', values='close') - - etf_nav_data = None - if etf_nav_data_list: - etf_nav_data = pd.concat(etf_nav_data_list) - if 'code' in etf_nav_data.columns and 'nav' in etf_nav_data.columns: - etf_nav_data = etf_nav_data.reset_index() - if 'index' in etf_nav_data.columns: - etf_nav_data = etf_nav_data.rename(columns={'index': 'date'}) - etf_nav_data['date'] = pd.to_datetime(etf_nav_data['date']).dt.normalize() - etf_nav_data = etf_nav_data.pivot_table(index='date', columns='code', values='nav') - - # 基准数据 - benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date) - if benchmark_data is not None: - benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize() - print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") - - return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes - - def __enter__(self): - self._start_tunnel() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._stop_tunnel() - - -# 简化接口 -def fetch_rotation_data(config_path: str = "strategies/rotation/config.yaml") -> dict: - """ - 获取轮动策略数据(简化接口) - - Args: - config_path: 配置文件路径 - - Returns: - { - 'index_data': 指数收盘价DataFrame, - 'etf_data': ETF价格DataFrame, - 'etf_nav_data': ETF净值DataFrame, - 'benchmark_data': 基准DataFrame, - 'valid_codes': 有效代码列表, - 'index_ohlcv_data': 原始OHLCV数据字典 - } - """ - import yaml - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - source = HybridDataSource.from_yaml(config_path) - - index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \ - source.fetch_all( - code_config=config.get('code_list', {}), - benchmark_code=config.get('benchmark', {}).get('code', '000300.SH'), - start_date=config.get('start_date', '2019-01-01'), - end_date=config.get('end_date', datetime.now().strftime('%Y-%m-%d')) - ) - - return { - 'index_data': index_data, - 'etf_data': etf_data, - 'etf_nav_data': etf_nav_data, - 'benchmark_data': benchmark_data, - 'valid_codes': valid_codes, - 'index_ohlcv_data': index_ohlcv_data - } \ No newline at end of file diff --git a/datasource/universal_fetcher.py b/datasource/universal_fetcher.py index 15c1e2f..4cf799b 100644 --- a/datasource/universal_fetcher.py +++ b/datasource/universal_fetcher.py @@ -455,4 +455,64 @@ class UniversalDataFetcher: def is_supported(self, code: str) -> bool: """判断是否支持该代码""" - return AssetTypeDetector.detect(code) != AssetType.UNKNOWN \ No newline at end of file + return AssetTypeDetector.detect(code) != AssetType.UNKNOWN + + # ============================================================ + # 扩展层:资产类型特有方法(复权/净值/溢价率) + # ============================================================ + + def fetch_etf_adj( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取 A股 ETF 后复权价格 + + 通过 fund_daily + fund_adj 手动计算后复权价格 + - 消除份额折算(拆分)对收益率的影响 + - 适用于计算真实收益率 + + Args: + code: ETF代码,如 '159915.SZ', '513100.SH' + start_date: 开始日期 'YYYY-MM-DD' + end_date: 结束日期 'YYYY-MM-DD' + + Returns: + DataFrame with columns: date, open, close, adj_factor, close_hfq + + 示例: + # 纳指ETF后复权(正确计算收益率) + df = fetcher.fetch_etf_adj("513100.SH", "2020-01-01", "2024-12-31") + # 使用 close_hfq 计算收益率,而非 close + """ + return self._tushare.fetch_etf_adj(code, start_date, end_date) + + def fetch_us_adj( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取美股复权价格 + + 使用 YFinance auto_adjust=True + - 消除拆分(split)和分红(dividend)对价格的影响 + - 适用于美股股票/ETF + + Args: + code: 美股代码,如 'AAPL', 'TSLA', 'QQQ' + start_date: 开始日期 'YYYY-MM-DD' + end_date: 结束日期 'YYYY-MM-DD' + + Returns: + DataFrame with columns: date, open, high, low, close, volume (复权后) + + 示例: + # 苹果复权价格(包含分红和拆分调整) + df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31") + """ + self._start_tunnel() + return self._yfinance.fetch_adj(code, start_date, end_date) \ No newline at end of file diff --git a/datasource/yfinance_source.py b/datasource/yfinance_source.py index 76d53b3..959369e 100644 --- a/datasource/yfinance_source.py +++ b/datasource/yfinance_source.py @@ -114,6 +114,70 @@ class YFinanceSource: print(f"YFinance下载 {code} ({yf_code}) 失败: {e}") return None + def fetch_adj(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """ + 获取复权价格数据 + + 使用 auto_adjust=True 获取复权后的价格 + - 消除拆分(split)和分红(dividend)对价格的影响 + - 适用于美股股票/ETF + + Args: + code: 代码(如 'AAPL', 'TSLA', 'QQQ') + start_date: 开始日期 'YYYY-MM-DD' + end_date: 结束日期 'YYYY-MM-DD' + + Returns: + DataFrame with columns: date, open, high, low, close, volume (复权后) + """ + import yfinance as yf + + # 添加延迟避免限流 + time.sleep(self._delay) + + # 转换代码格式 + yf_code = self.CODE_MAP.get(code, code) + + try: + ticker = yf.Ticker(yf_code) + + # end_date 需要加一天(yfinance的end是排他的) + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1) + + # auto_adjust=True 获取复权价格 + df = ticker.history( + start=start_date, + end=end_dt.strftime("%Y-%m-%d"), + auto_adjust=True + ) + + if df is None or len(df) == 0: + return None + + # 标准化列名 + df = df.rename(columns={ + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Volume": "volume", + }) + + # 确保索引是日期格式 + df.index = pd.to_datetime(df.index, utc=True).tz_localize(None).normalize() + df.index.name = "date" + + # 添加代码列和标记 + df["code"] = code + df.attrs['code'] = code + df.attrs['adjusted'] = True + + return df[['code', 'open', 'high', 'low', 'close', 'volume']] + + except Exception as e: + print(f"YFinance下载复权数据 {code} ({yf_code}) 失败: {e}") + return None + def is_yfinance_code(self, code: str) -> bool: """判断是否需要YFinance获取""" # 非A股代码 diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py index f60b56f..96e8bbb 100644 --- a/strategies/rotation/strategy.py +++ b/strategies/rotation/strategy.py @@ -5,6 +5,7 @@ """ import pandas as pd +import numpy as np import yaml from datetime import datetime from pathlib import Path @@ -113,7 +114,7 @@ class RotationStrategy(StrategyBase): Args: use_flask_api: 是否使用 Flask API 服务获取数据(默认 True) - False 则使用本地 HybridDataSource + False 则使用本地 UniversalDataFetcher """ code_list_config = self.config.get('code_list', {}) benchmark_config = self.config.get('benchmark', {}) @@ -237,6 +238,12 @@ class RotationStrategy(StrategyBase): index_close_dict[code] = df['close'] index_close = pd.DataFrame(index_close_dict) if index_close_dict else None + # 获取 A 股 SSE 官方交易日历 + from datasource.tushare_source import TushareSource + tushare = TushareSource() + a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date) + print(f"A股交易日历: {len(a_share_dates)} 天") + return { 'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame} 'index_close': index_close, # 对齐后的收盘价(宽格式) @@ -245,7 +252,8 @@ class RotationStrategy(StrategyBase): 'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict} 'benchmark_data': benchmark_data, # 基准收盘价 Series 'valid_codes': valid_codes, # 有效指数代码列表 - 'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 + 'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射 + 'a_share_dates': a_share_dates # A股SSE交易日历 } def _get_data_from_local( @@ -253,33 +261,90 @@ class RotationStrategy(StrategyBase): code_list_config: dict, benchmark_code: str ) -> dict: - """使用本地 HybridDataSource 获取数据""" - from datasource import HybridDataSource + """使用本地 UniversalDataFetcher 获取数据""" + from datasource import UniversalDataFetcher + from datasource.tushare_source import TushareSource ssh_config = self.config.get('ssh_tunnel', {}) - data_source = HybridDataSource( + fetcher = UniversalDataFetcher( ssh_config=ssh_config, use_cache=self.config.get('use_cache', True) ) - # 调用 fetch_all - index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \ - data_source.fetch_all( - code_config=code_list_config, - benchmark_code=benchmark_code, - start_date=self.start_date, - end_date=self.end_date - ) + index_codes = list(code_list_config.keys()) + etf_code_map = {idx_code: cfg['etf'] for idx_code, cfg in code_list_config.items() if cfg.get('etf')} + + # 获取指数数据 + index_ohlcv_data = {} + valid_codes = [] + + with fetcher: # 使用上下文管理器自动管理 SSH 隧道 + for code in index_codes: + data = fetcher.fetch(code, self.start_date, self.end_date) + if data is not None and len(data) > 0: + index_ohlcv_data[code] = data + valid_codes.append(code) + print(f"✓ {code}: {len(data)} 条") + else: + print(f"✗ {code}: 无数据") + + # 构建宽格式收盘价 + index_close = None + if index_ohlcv_data: + close_list = [] + for code, df in index_ohlcv_data.items(): + close_df = df[['close']].copy() + close_df.columns = [code] + close_list.append(close_df) + index_close = pd.concat(close_list, axis=1) + + # 获取 ETF 数据 + etf_data = None + etf_nav_data = None + + tushare = TushareSource() + + if etf_code_map: + etf_price_list = [] + etf_nav_list = [] + + for idx_code, etf_code in etf_code_map.items(): + # ETF 价格 + etf_df = tushare.fetch_etf(etf_code, self.start_date, self.end_date) + if etf_df is not None and len(etf_df) > 0: + etf_df = etf_df[['close']].copy() + etf_df.columns = [etf_code] + etf_price_list.append(etf_df) + + # ETF 净值 + nav_df = tushare.fetch_etf_nav(etf_code, self.start_date, self.end_date) + if nav_df is not None and len(nav_df) > 0: + nav_df = nav_df[['nav']].copy() + nav_df.columns = [etf_code] + etf_nav_list.append(nav_df) + + if etf_price_list: + etf_data = pd.concat(etf_price_list, axis=1) + if etf_nav_list: + etf_nav_data = pd.concat(etf_nav_list, axis=1) + + # 基准数据 + benchmark_data = tushare.fetch_index(benchmark_code, self.start_date, self.end_date) + + # A股交易日历 + a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date) + print(f"A股交易日历: {len(a_share_dates)} 天") return { 'index_data': index_ohlcv_data, # 原始OHLCV数据 - 'index_close': index_data, # 对齐后的收盘价(宽格式) + 'index_close': index_close, # 对齐后的收盘价(宽格式) 'etf_data': etf_data, 'etf_nav_data': etf_nav_data, 'benchmark_data': benchmark_data, 'valid_codes': valid_codes, - 'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 + 'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射 + 'a_share_dates': a_share_dates # A股SSE交易日历 } def compute_factors(self, data: dict) -> pd.DataFrame: @@ -290,17 +355,20 @@ class RotationStrategy(StrategyBase): index_data = data['index_data'] valid_codes = data['valid_codes'] - # 获取A股交易日历作为基准(使用已有的对齐后数据索引) - index_close = data.get('index_close') - if index_close is not None: - a_share_dates = index_close.index - else: - for code in valid_codes: - if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'): - a_share_dates = index_data[code].index - break + # 获取 A 股 SSE 官方交易日历(优先使用已获取的) + a_share_dates = data.get('a_share_dates') + if a_share_dates is None or len(a_share_dates) == 0: + # 回退:使用已有的对齐后数据索引 + index_close = data.get('index_close') + if index_close is not None: + a_share_dates = index_close.index else: - a_share_dates = index_data[valid_codes[0]].index + for code in valid_codes: + if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'): + a_share_dates = index_data[code].index + break + else: + a_share_dates = index_data[valid_codes[0]].index factor_values = {} final_valid_codes = [] @@ -408,8 +476,15 @@ class RotationStrategy(StrategyBase): # 4. 执行回测 print("\n执行回测...") - # 获取A股交易日历(从因子数据索引) - a_share_dates = signals.index + # 获取 A 股 SSE 官方交易日历(优先使用已获取的) + a_share_dates = data.get('a_share_dates') + if a_share_dates is None or len(a_share_dates) == 0: + a_share_dates = signals.index + + # 将信号对齐到 A 股日历 + if a_share_dates is not signals.index: + signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]]) + print(f" 信号对齐到A股日历: {len(signals)} 天") # 计算日收益率:先在原始交易日历计算,再对齐到A股日历 # 关键:与因子计算逻辑一致,避免交易日不对齐导致收益率NaN