refactor(rotation): 移除数据缓存 + 修复空值和pct_change警告

- 移除CSV本地缓存(cache_dir、_cache_path、_premium_cache_path、_save_premium_cache)
- 每次运行直接从API获取数据,简化DataCache类
- 修复_get_etf_prices中open/close为None时的空值处理(中证指数API不提供OHLC)
- 修复pct_change的FutureWarning(显式传fill_method=None)
- 更新trade_cost注释
This commit is contained in:
2026-06-03 00:54:48 +08:00
parent d1139a9ee9
commit 524fa5f513
2 changed files with 24 additions and 85 deletions

View File

@@ -151,7 +151,7 @@ rotation:
rebalance: rebalance:
min_hold_days: 1 min_hold_days: 1
score_threshold: 0.0 score_threshold: 0.0
trade_cost: 0.001 # 0.1% 交易成本 trade_cost: 0.001 # 万1 交易成本场内ETF万0.5免5
# ============================================================ # ============================================================
# 溢价控制配置 # 溢价控制配置

View File

@@ -81,53 +81,18 @@ def is_crash(prices: np.ndarray) -> bool:
# ============================================================ # ============================================================
class DataCache: class DataCache:
"""CSV-cached data fetching for index (raw) and ETF (hfq)""" """Data fetching (no local cache, always from API)"""
def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120): def __init__(self, base_url: str, timeout: int = 120):
self.base_url = base_url.rstrip('/') self.base_url = base_url.rstrip('/')
self.api_path = '/api/v1/ohlcv' self.api_path = '/api/v1/ohlcv'
self.timeout = timeout self.timeout = timeout
if cache_dir is None: # premium data in memory: {trade_code: {date_str: premium_ratio}}
cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
# premium data cache: {trade_code: {date_str: premium_ratio}}
self.premium_data: Dict[str, Dict[str, float]] = {} self.premium_data: Dict[str, Dict[str, float]] = {}
def _cache_path(self, code: str, adj: str) -> Path:
prefix = 'index' if adj == 'raw' else 'etf'
safe_code = code.replace('=', '_').replace('^', '_')
return self.cache_dir / f"{prefix}_{safe_code}.csv"
def _premium_cache_path(self, code: str) -> Path:
safe_code = code.replace('=', '_').replace('^', '_')
return self.cache_dir / f"premium_{safe_code}.csv"
def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]: def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
"""Preload full history and cache to CSV""" """Fetch data from API"""
cache_path = self._cache_path(code, adj) return self._fetch_api(code, start_date, end_date, adj)
if cache_path.exists():
try:
df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
if len(df) > 0:
cs = df.index.min().strftime('%Y-%m-%d')
ce = df.index.max().strftime('%Y-%m-%d')
if cs <= start_date and ce >= end_date:
return df
new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
if new_start <= end_date:
new_df = self._fetch_api(code, new_start, end_date, adj)
if new_df is not None and len(new_df) > 0:
df = pd.concat([df, new_df])
df = df[~df.index.duplicated(keep='last')]
df.to_csv(cache_path)
return df
except Exception:
pass
df = self._fetch_api(code, start_date, end_date, adj)
if df is not None and len(df) > 0:
df.to_csv(cache_path)
return df
def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]: def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
"""Fetch from Flask API, also extracts premium_series for ETFs""" """Fetch from Flask API, also extracts premium_series for ETFs"""
@@ -155,11 +120,13 @@ class DataCache:
df = df.set_index('date').sort_index() df = df.set_index('date').sort_index()
keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns] keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
df = df[keep] df = df[keep]
# Extract and cache premium_series (ETF only) # Extract premium_series (ETF only) and store in memory
premium_series = data.get('premium_series', []) premium_series = data.get('premium_series', [])
if premium_series: if premium_series:
df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series} df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
self._save_premium_cache(code, df.attrs['premium_series']) if code not in self.premium_data:
self.premium_data[code] = {}
self.premium_data[code].update(df.attrs['premium_series'])
print(f" + {code}: {len(df)} rows ({adj})") print(f" + {code}: {len(df)} rows ({adj})")
return df return df
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e: except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
@@ -174,50 +141,20 @@ class DataCache:
return None return None
return None return None
def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]):
"""Save premium data to CSV cache"""
try:
cache_path = self._premium_cache_path(code)
pd.DataFrame(
[{'date': k, 'premium': v} for k, v in premium_dict.items()]
).to_csv(cache_path, index=False)
except Exception:
pass
def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]: def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
"""Load premium data for an ETF code from cache, with incremental update. """Load premium data for an ETF code, fetch from API if not in memory."""
If cache exists but doesn't cover end_date, fetches the gap."""
if code in self.premium_data: if code in self.premium_data:
# Already in memory - check if up-to-date
if end_date: if end_date:
dates = sorted(self.premium_data[code].keys()) dates = sorted(self.premium_data[code].keys())
if dates and dates[-1] >= end_date: if dates and dates[-1] >= end_date:
return self.premium_data[code] return self.premium_data[code]
else: else:
return self.premium_data[code] return self.premium_data[code]
cache_path = self._premium_cache_path(code)
if cache_path.exists():
try:
df = pd.read_csv(cache_path)
if len(df) > 0 and 'date' in df.columns and 'premium' in df.columns:
self.premium_data[code] = dict(zip(df['date'].astype(str), df['premium']))
# Check if cache covers end_date
if end_date:
latest_cached = max(self.premium_data[code].keys())
if latest_cached >= end_date:
return self.premium_data[code]
# Cache is stale - fetch gap from latest_cached+1 to end_date
fetch_start = (pd.Timestamp(latest_cached) + timedelta(days=1)).strftime('%Y-%m-%d')
self._fetch_premium_api(code, fetch_start, end_date)
return self.premium_data[code]
except Exception:
pass
# No cache: fetch full history from API
self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d')) self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
return self.premium_data.get(code) return self.premium_data.get(code)
def _fetch_premium_api(self, code: str, start_date: str, end_date: str): def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
"""Fetch premium_series from API and merge into cache""" """Fetch premium_series from API and store in memory"""
url = f"{self.base_url}{self.api_path}" url = f"{self.base_url}{self.api_path}"
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'} params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
for attempt in range(3): for attempt in range(3):
@@ -237,7 +174,6 @@ class DataCache:
if code not in self.premium_data: if code not in self.premium_data:
self.premium_data[code] = {} self.premium_data[code] = {}
self.premium_data[code].update(new_data) self.premium_data[code].update(new_data)
self._save_premium_cache(code, self.premium_data[code])
print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})") print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
return return
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError): except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError):
@@ -474,14 +410,17 @@ class SimpleRotationStrategy:
today_date = recent.index[-1] today_date = recent.index[-1]
if (date - today_date).days > 5: if (date - today_date).days > 5:
return None return None
close = float(today_row['close']) close_raw = today_row.get('close')
prev_close = float(prev_row['close']) prev_close_raw = prev_row.get('close')
# Handle missing open (common for index data like 931862.CSI) if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw):
open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close return None
close = float(close_raw)
prev_close = float(prev_close_raw)
# Handle missing/invalid open (common for index data like 931862.CSI)
raw_open = today_row.get('open')
open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close
if pd.isna(open_price) or open_price == 0: if pd.isna(open_price) or open_price == 0:
open_price = close open_price = close
if pd.isna(close) or pd.isna(prev_close):
return None
return { return {
'open': open_price, 'open': open_price,
'close': close, 'close': close,
@@ -995,7 +934,7 @@ class SimpleRotationStrategy:
if self.benchmark_data is not None: if self.benchmark_data is not None:
bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill') bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
if bm_close is not None and not bm_close.isna().all(): if bm_close is not None and not bm_close.isna().all():
benchmark_nav = (1 + bm_close.pct_change()).cumprod() benchmark_nav = (1 + bm_close.pct_change(fill_method=None)).cumprod()
first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1 first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1
benchmark_nav = benchmark_nav / first_valid benchmark_nav = benchmark_nav / first_valid
@@ -1005,7 +944,7 @@ class SimpleRotationStrategy:
if code in self.index_data: if code in self.index_data:
price = self.index_data[code]['close'].reindex(df.index, method='ffill') price = self.index_data[code]['close'].reindex(df.index, method='ffill')
if price is not None and not price.isna().all(): if price is not None and not price.isna().all():
nav_s = (1 + price.pct_change()).cumprod() nav_s = (1 + price.pct_change(fill_method=None)).cumprod()
fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1 fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1
asset_navs[code] = nav_s / fv asset_navs[code] = nav_s / fv
@@ -1023,7 +962,7 @@ class SimpleRotationStrategy:
# Benchmark metrics # Benchmark metrics
b_total_return = b_annual = b_sharpe = b_dd = 0 b_total_return = b_annual = b_sharpe = b_dd = 0
if benchmark_nav is not None: if benchmark_nav is not None:
bm_ret = benchmark_nav.pct_change() bm_ret = benchmark_nav.pct_change(fill_method=None)
b_total_return = benchmark_nav.iloc[-1] - 1 b_total_return = benchmark_nav.iloc[-1] - 1
b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0 b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0 b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0