diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py index fb10a51..11d11ef 100644 --- a/rotation/simple_rotation.py +++ b/rotation/simple_rotation.py @@ -78,12 +78,18 @@ class DataCache: cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache' self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) + # premium data cache: {trade_code: {date_str: premium_ratio}} + self.premium_data: Dict[str, Dict[str, float]] = {} def _cache_path(self, code: str, adj: str) -> Path: prefix = 'index' if adj == 'raw' else 'etf' safe_code = code.replace('=', '_').replace('^', '_') return self.cache_dir / f"{prefix}_{safe_code}.csv" + def _premium_cache_path(self, code: str) -> Path: + safe_code = code.replace('=', '_').replace('^', '_') + return self.cache_dir / f"premium_{safe_code}.csv" + def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]: """Preload full history and cache to CSV""" cache_path = self._cache_path(code, adj) @@ -111,7 +117,7 @@ class DataCache: return df def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]: - """Fetch from Flask API""" + """Fetch from Flask API, also extracts premium_series for ETFs""" url = f"{self.base_url}{self.api_path}" params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj} for attempt in range(3): @@ -136,6 +142,11 @@ class DataCache: df = df.set_index('date').sort_index() keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns] df = df[keep] + # Extract and cache premium_series (ETF only) + premium_series = data.get('premium_series', []) + if premium_series: + df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series} + self._save_premium_cache(code, df.attrs['premium_series']) print(f" + {code}: {len(df)} rows ({adj})") return df except requests.exceptions.Timeout: @@ -148,6 +159,61 @@ class DataCache: return None return None + def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]): + """Save premium data to CSV cache""" + try: + cache_path = self._premium_cache_path(code) + pd.DataFrame( + [{'date': k, 'premium': v} for k, v in premium_dict.items()] + ).to_csv(cache_path, index=False) + except Exception: + pass + + def preload_premium(self, code: str, start_date: str = '2000-01-01', end_date: str = None) -> Optional[Dict[str, float]]: + """Load premium data for an ETF code from cache, or fetch from API if not available""" + if code in self.premium_data: + return self.premium_data[code] + cache_path = self._premium_cache_path(code) + if cache_path.exists(): + try: + df = pd.read_csv(cache_path) + if len(df) > 0 and 'date' in df.columns and 'premium' in df.columns: + self.premium_data[code] = dict(zip(df['date'].astype(str), df['premium'])) + return self.premium_data[code] + except Exception: + pass + # No cache: fetch premium_series directly from API (returns full history) + if end_date is None: + end_date = datetime.now().strftime('%Y-%m-%d') + url = f"{self.base_url}{self.api_path}" + params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'} + for attempt in range(3): + try: + resp = requests.get(url, params=params, timeout=self.timeout) + if resp.status_code != 200: + if attempt < 2: + time.sleep(1) + continue + return None + data = resp.json() + if 'error' in data: + return None + premium_series = data.get('premium_series', []) + if premium_series: + premium_dict = {item['date']: item['premium'] for item in premium_series} + self.premium_data[code] = premium_dict + self._save_premium_cache(code, premium_dict) + print(f" + premium {code}: {len(premium_dict)} days") + return premium_dict + return None + except requests.exceptions.Timeout: + if attempt < 2: + continue + return None + except Exception: + return None + return None + def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]: """Fetch trading calendar from API""" url = f"{self.base_url}/api/v1/trading-calendar" @@ -218,10 +284,12 @@ class SimpleRotationStrategy: self.signal_codes = [] self.signal_to_trade = {} self.code_to_group = {} + self.trade_code_to_group = {} for code, asset in self.config.asset_pools.assets.items(): self.signal_codes.append(asset.signal_source) self.signal_to_trade[asset.signal_source] = asset.trade_source self.code_to_group[asset.signal_source] = asset.group + self.trade_code_to_group[asset.trade_source] = asset.group # Data source data_source = self.config.data.sources[0] @@ -260,7 +328,10 @@ class SimpleRotationStrategy: df = self.data_cache.preload(code, preload_start, end_date, adj=adj) if df is not None: self.etf_data[code] = df - print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK") + # Load premium data cache for all ETF trade codes + for code in trade_codes: + self.data_cache.preload_premium(code) + print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK, premium: {len(self.data_cache.premium_data)} loaded") def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]: """Compute momentum for a single code on a given date""" @@ -620,15 +691,20 @@ class SimpleRotationStrategy: return idx_ret, etf_ret - def _compute_premium(self, code: str, idx_close: float, etf_close: float) -> Optional[float]: - """Compute premium = (etf_close - index_close) / index_close. - Only meaningful for ETFs that track an index (not bonds).""" - group = self.code_to_group.get(code, '') + def _compute_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]: + """Get real premium from API data cache: (ETF_price - NAV) / NAV. + Returns None for BOND or when premium data is unavailable.""" + group = self.trade_code_to_group.get(trade_code, '') if group == 'BOND': return None - if idx_close is None or etf_close is None or idx_close == 0: + premium_dict = self.data_cache.premium_data.get(trade_code) + if not premium_dict: return None - return round((etf_close - idx_close) / idx_close, 6) + date_str = date.strftime('%Y-%m-%d') + val = premium_dict.get(date_str) + if val is None: + return None + return round(float(val), 6) def _build_day_assets(self, record: dict, date: pd.Timestamp, entry_info: Dict[str, dict]) -> dict: @@ -652,7 +728,7 @@ class SimpleRotationStrategy: idx_close = self._get_index_close(code, date) etf_close = self._get_etf_close(trade_code, date) idx_ret, etf_ret_ctc = self._get_daily_returns(code, date) - premium = self._compute_premium(code, idx_close, etf_close) + premium = self._compute_premium(trade_code, date) # Entry / holding info ei = entry_info.get(code) if is_held else None