diff --git a/rotation/config_simple.yaml b/rotation/config_simple.yaml index 996cc95..b883be1 100644 --- a/rotation/config_simple.yaml +++ b/rotation/config_simple.yaml @@ -151,7 +151,7 @@ rotation: rebalance: min_hold_days: 1 score_threshold: 0.0 - trade_cost: 0.001 # 0.1% 交易成本 + trade_cost: 0.001 # 万1 交易成本(场内ETF万0.5免5) # ============================================================ # 溢价控制配置 diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py index b5b21c2..9dde1c3 100644 --- a/rotation/simple_rotation.py +++ b/rotation/simple_rotation.py @@ -81,53 +81,18 @@ def is_crash(prices: np.ndarray) -> bool: # ============================================================ class DataCache: - """CSV-cached data fetching for index (raw) and ETF (hfq)""" + """Data fetching (no local cache, always from API)""" - def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120): + def __init__(self, base_url: str, timeout: int = 120): self.base_url = base_url.rstrip('/') self.api_path = '/api/v1/ohlcv' self.timeout = timeout - if cache_dir is None: - cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache' - self.cache_dir = Path(cache_dir) - self.cache_dir.mkdir(parents=True, exist_ok=True) - # premium data cache: {trade_code: {date_str: premium_ratio}} + # premium data in memory: {trade_code: {date_str: premium_ratio}} self.premium_data: Dict[str, Dict[str, float]] = {} - def _cache_path(self, code: str, adj: str) -> Path: - prefix = 'index' if adj == 'raw' else 'etf' - safe_code = code.replace('=', '_').replace('^', '_') - return self.cache_dir / f"{prefix}_{safe_code}.csv" - - def _premium_cache_path(self, code: str) -> Path: - safe_code = code.replace('=', '_').replace('^', '_') - return self.cache_dir / f"premium_{safe_code}.csv" - def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]: - """Preload full history and cache to CSV""" - cache_path = self._cache_path(code, adj) - if cache_path.exists(): - try: - df = pd.read_csv(cache_path, index_col='date', parse_dates=True) - if len(df) > 0: - cs = df.index.min().strftime('%Y-%m-%d') - ce = df.index.max().strftime('%Y-%m-%d') - if cs <= start_date and ce >= end_date: - return df - new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d') - if new_start <= end_date: - new_df = self._fetch_api(code, new_start, end_date, adj) - if new_df is not None and len(new_df) > 0: - df = pd.concat([df, new_df]) - df = df[~df.index.duplicated(keep='last')] - df.to_csv(cache_path) - return df - except Exception: - pass - df = self._fetch_api(code, start_date, end_date, adj) - if df is not None and len(df) > 0: - df.to_csv(cache_path) - return df + """Fetch data from API""" + return self._fetch_api(code, start_date, end_date, adj) def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]: """Fetch from Flask API, also extracts premium_series for ETFs""" @@ -155,11 +120,13 @@ class DataCache: df = df.set_index('date').sort_index() keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns] df = df[keep] - # Extract and cache premium_series (ETF only) + # Extract premium_series (ETF only) and store in memory premium_series = data.get('premium_series', []) if premium_series: df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series} - self._save_premium_cache(code, df.attrs['premium_series']) + if code not in self.premium_data: + self.premium_data[code] = {} + self.premium_data[code].update(df.attrs['premium_series']) print(f" + {code}: {len(df)} rows ({adj})") return df except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e: @@ -174,50 +141,20 @@ class DataCache: return None return None - def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]): - """Save premium data to CSV cache""" - try: - cache_path = self._premium_cache_path(code) - pd.DataFrame( - [{'date': k, 'premium': v} for k, v in premium_dict.items()] - ).to_csv(cache_path, index=False) - except Exception: - pass - def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]: - """Load premium data for an ETF code from cache, with incremental update. - If cache exists but doesn't cover end_date, fetches the gap.""" + """Load premium data for an ETF code, fetch from API if not in memory.""" if code in self.premium_data: - # Already in memory - check if up-to-date if end_date: dates = sorted(self.premium_data[code].keys()) if dates and dates[-1] >= end_date: return self.premium_data[code] else: return self.premium_data[code] - cache_path = self._premium_cache_path(code) - if cache_path.exists(): - try: - df = pd.read_csv(cache_path) - if len(df) > 0 and 'date' in df.columns and 'premium' in df.columns: - self.premium_data[code] = dict(zip(df['date'].astype(str), df['premium'])) - # Check if cache covers end_date - if end_date: - latest_cached = max(self.premium_data[code].keys()) - if latest_cached >= end_date: - return self.premium_data[code] - # Cache is stale - fetch gap from latest_cached+1 to end_date - fetch_start = (pd.Timestamp(latest_cached) + timedelta(days=1)).strftime('%Y-%m-%d') - self._fetch_premium_api(code, fetch_start, end_date) - return self.premium_data[code] - except Exception: - pass - # No cache: fetch full history from API self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d')) return self.premium_data.get(code) def _fetch_premium_api(self, code: str, start_date: str, end_date: str): - """Fetch premium_series from API and merge into cache""" + """Fetch premium_series from API and store in memory""" url = f"{self.base_url}{self.api_path}" params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'} for attempt in range(3): @@ -237,7 +174,6 @@ class DataCache: if code not in self.premium_data: self.premium_data[code] = {} self.premium_data[code].update(new_data) - self._save_premium_cache(code, self.premium_data[code]) print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})") return except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError): @@ -474,14 +410,17 @@ class SimpleRotationStrategy: today_date = recent.index[-1] if (date - today_date).days > 5: return None - close = float(today_row['close']) - prev_close = float(prev_row['close']) - # Handle missing open (common for index data like 931862.CSI) - open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close + close_raw = today_row.get('close') + prev_close_raw = prev_row.get('close') + if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw): + return None + close = float(close_raw) + prev_close = float(prev_close_raw) + # Handle missing/invalid open (common for index data like 931862.CSI) + raw_open = today_row.get('open') + open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close if pd.isna(open_price) or open_price == 0: open_price = close - if pd.isna(close) or pd.isna(prev_close): - return None return { 'open': open_price, 'close': close, @@ -995,7 +934,7 @@ class SimpleRotationStrategy: if self.benchmark_data is not None: bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill') if bm_close is not None and not bm_close.isna().all(): - benchmark_nav = (1 + bm_close.pct_change()).cumprod() + benchmark_nav = (1 + bm_close.pct_change(fill_method=None)).cumprod() first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1 benchmark_nav = benchmark_nav / first_valid @@ -1005,7 +944,7 @@ class SimpleRotationStrategy: if code in self.index_data: price = self.index_data[code]['close'].reindex(df.index, method='ffill') if price is not None and not price.isna().all(): - nav_s = (1 + price.pct_change()).cumprod() + nav_s = (1 + price.pct_change(fill_method=None)).cumprod() fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1 asset_navs[code] = nav_s / fv @@ -1023,7 +962,7 @@ class SimpleRotationStrategy: # Benchmark metrics b_total_return = b_annual = b_sharpe = b_dd = 0 if benchmark_nav is not None: - bm_ret = benchmark_nav.pct_change() + bm_ret = benchmark_nav.pct_change(fill_method=None) b_total_return = benchmark_nav.iloc[-1] - 1 b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0 b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0