refactor(rotation): 移除数据缓存 + 修复空值和pct_change警告

- 移除CSV本地缓存(cache_dir、_cache_path、_premium_cache_path、_save_premium_cache)
- 每次运行直接从API获取数据,简化DataCache类
- 修复_get_etf_prices中open/close为None时的空值处理(中证指数API不提供OHLC)
- 修复pct_change的FutureWarning(显式传fill_method=None)
- 更新trade_cost注释
This commit is contained in:
2026-06-03 00:54:48 +08:00
parent d1139a9ee9
commit 524fa5f513
2 changed files with 24 additions and 85 deletions

View File

@@ -151,7 +151,7 @@ rotation:
rebalance:
min_hold_days: 1
score_threshold: 0.0
trade_cost: 0.001 # 0.1% 交易成本
trade_cost: 0.001 # 万1 交易成本场内ETF万0.5免5
# ============================================================
# 溢价控制配置

View File

@@ -81,53 +81,18 @@ def is_crash(prices: np.ndarray) -> bool:
# ============================================================
class DataCache:
"""CSV-cached data fetching for index (raw) and ETF (hfq)"""
"""Data fetching (no local cache, always from API)"""
def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
def __init__(self, base_url: str, timeout: int = 120):
self.base_url = base_url.rstrip('/')
self.api_path = '/api/v1/ohlcv'
self.timeout = timeout
if cache_dir is None:
cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
# premium data cache: {trade_code: {date_str: premium_ratio}}
# premium data in memory: {trade_code: {date_str: premium_ratio}}
self.premium_data: Dict[str, Dict[str, float]] = {}
def _cache_path(self, code: str, adj: str) -> Path:
prefix = 'index' if adj == 'raw' else 'etf'
safe_code = code.replace('=', '_').replace('^', '_')
return self.cache_dir / f"{prefix}_{safe_code}.csv"
def _premium_cache_path(self, code: str) -> Path:
safe_code = code.replace('=', '_').replace('^', '_')
return self.cache_dir / f"premium_{safe_code}.csv"
def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
"""Preload full history and cache to CSV"""
cache_path = self._cache_path(code, adj)
if cache_path.exists():
try:
df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
if len(df) > 0:
cs = df.index.min().strftime('%Y-%m-%d')
ce = df.index.max().strftime('%Y-%m-%d')
if cs <= start_date and ce >= end_date:
return df
new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
if new_start <= end_date:
new_df = self._fetch_api(code, new_start, end_date, adj)
if new_df is not None and len(new_df) > 0:
df = pd.concat([df, new_df])
df = df[~df.index.duplicated(keep='last')]
df.to_csv(cache_path)
return df
except Exception:
pass
df = self._fetch_api(code, start_date, end_date, adj)
if df is not None and len(df) > 0:
df.to_csv(cache_path)
return df
"""Fetch data from API"""
return self._fetch_api(code, start_date, end_date, adj)
def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
"""Fetch from Flask API, also extracts premium_series for ETFs"""
@@ -155,11 +120,13 @@ class DataCache:
df = df.set_index('date').sort_index()
keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
df = df[keep]
# Extract and cache premium_series (ETF only)
# Extract premium_series (ETF only) and store in memory
premium_series = data.get('premium_series', [])
if premium_series:
df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
self._save_premium_cache(code, df.attrs['premium_series'])
if code not in self.premium_data:
self.premium_data[code] = {}
self.premium_data[code].update(df.attrs['premium_series'])
print(f" + {code}: {len(df)} rows ({adj})")
return df
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
@@ -174,50 +141,20 @@ class DataCache:
return None
return None
def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]):
"""Save premium data to CSV cache"""
try:
cache_path = self._premium_cache_path(code)
pd.DataFrame(
[{'date': k, 'premium': v} for k, v in premium_dict.items()]
).to_csv(cache_path, index=False)
except Exception:
pass
def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
"""Load premium data for an ETF code from cache, with incremental update.
If cache exists but doesn't cover end_date, fetches the gap."""
"""Load premium data for an ETF code, fetch from API if not in memory."""
if code in self.premium_data:
# Already in memory - check if up-to-date
if end_date:
dates = sorted(self.premium_data[code].keys())
if dates and dates[-1] >= end_date:
return self.premium_data[code]
else:
return self.premium_data[code]
cache_path = self._premium_cache_path(code)
if cache_path.exists():
try:
df = pd.read_csv(cache_path)
if len(df) > 0 and 'date' in df.columns and 'premium' in df.columns:
self.premium_data[code] = dict(zip(df['date'].astype(str), df['premium']))
# Check if cache covers end_date
if end_date:
latest_cached = max(self.premium_data[code].keys())
if latest_cached >= end_date:
return self.premium_data[code]
# Cache is stale - fetch gap from latest_cached+1 to end_date
fetch_start = (pd.Timestamp(latest_cached) + timedelta(days=1)).strftime('%Y-%m-%d')
self._fetch_premium_api(code, fetch_start, end_date)
return self.premium_data[code]
except Exception:
pass
# No cache: fetch full history from API
self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
return self.premium_data.get(code)
def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
"""Fetch premium_series from API and merge into cache"""
"""Fetch premium_series from API and store in memory"""
url = f"{self.base_url}{self.api_path}"
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
for attempt in range(3):
@@ -237,7 +174,6 @@ class DataCache:
if code not in self.premium_data:
self.premium_data[code] = {}
self.premium_data[code].update(new_data)
self._save_premium_cache(code, self.premium_data[code])
print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
return
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError):
@@ -474,14 +410,17 @@ class SimpleRotationStrategy:
today_date = recent.index[-1]
if (date - today_date).days > 5:
return None
close = float(today_row['close'])
prev_close = float(prev_row['close'])
# Handle missing open (common for index data like 931862.CSI)
open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close
close_raw = today_row.get('close')
prev_close_raw = prev_row.get('close')
if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw):
return None
close = float(close_raw)
prev_close = float(prev_close_raw)
# Handle missing/invalid open (common for index data like 931862.CSI)
raw_open = today_row.get('open')
open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close
if pd.isna(open_price) or open_price == 0:
open_price = close
if pd.isna(close) or pd.isna(prev_close):
return None
return {
'open': open_price,
'close': close,
@@ -995,7 +934,7 @@ class SimpleRotationStrategy:
if self.benchmark_data is not None:
bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
if bm_close is not None and not bm_close.isna().all():
benchmark_nav = (1 + bm_close.pct_change()).cumprod()
benchmark_nav = (1 + bm_close.pct_change(fill_method=None)).cumprod()
first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1
benchmark_nav = benchmark_nav / first_valid
@@ -1005,7 +944,7 @@ class SimpleRotationStrategy:
if code in self.index_data:
price = self.index_data[code]['close'].reindex(df.index, method='ffill')
if price is not None and not price.isna().all():
nav_s = (1 + price.pct_change()).cumprod()
nav_s = (1 + price.pct_change(fill_method=None)).cumprod()
fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1
asset_navs[code] = nav_s / fv
@@ -1023,7 +962,7 @@ class SimpleRotationStrategy:
# Benchmark metrics
b_total_return = b_annual = b_sharpe = b_dd = 0
if benchmark_nav is not None:
bm_ret = benchmark_nav.pct_change()
bm_ret = benchmark_nav.pct_change(fill_method=None)
b_total_return = benchmark_nav.iloc[-1] - 1
b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0