Compare commits
5 Commits
74f0eebef0
...
524fa5f513
| Author | SHA1 | Date | |
|---|---|---|---|
| 524fa5f513 | |||
| d1139a9ee9 | |||
| a2b4289080 | |||
| e29f57749d | |||
| 81045f9d85 |
@@ -7,6 +7,7 @@ Flask API 数据源
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import requests
|
import requests
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from typing import Optional, Dict, List
|
from typing import Optional, Dict, List
|
||||||
@@ -18,6 +19,19 @@ from .models import OHLCVResponse, validate_ohlcv_response
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Clash 等代理在处理 TLS 1.3 + 后量子密钥交换时会触发 SSL EOF 错误
|
||||||
|
# trust_env=False 让 requests 忽略环境变量中的代理配置,直连目标服务器
|
||||||
|
_session = requests.Session()
|
||||||
|
_session.trust_env = False
|
||||||
|
|
||||||
|
def _http_get(url: str, params: dict = None, timeout: int = 120) -> requests.Response:
|
||||||
|
"""使用 requests 发起 GET 请求(trust_env=False 绕过系统代理)"""
|
||||||
|
return _session.get(url, params=params, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
class FlaskAPIDataSource:
|
class FlaskAPIDataSource:
|
||||||
"""
|
"""
|
||||||
@@ -110,24 +124,17 @@ class FlaskAPIDataSource:
|
|||||||
|
|
||||||
for attempt in range(self.retries):
|
for attempt in range(self.retries):
|
||||||
try:
|
try:
|
||||||
response = requests.get(
|
response = _http_get(url, params=params, timeout=self.timeout)
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
timeout=self.timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
if attempt < self.retries - 1:
|
if attempt < self.retries - 1:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f"✗ API请求失败: {response.status_code} - {response.text[:100]}")
|
print(f"✗ API请求失败: {response.status_code} - {response.text[:100]}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 尝试解析 JSON(支持 zstd 响应)
|
# 解析 JSON
|
||||||
try:
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except (json.JSONDecodeError, requests.exceptions.JSONDecodeError):
|
|
||||||
# 如果 response.json() 失败,手动解析
|
|
||||||
data = json.loads(response.text)
|
|
||||||
|
|
||||||
# 检查错误
|
# 检查错误
|
||||||
if 'error' in data:
|
if 'error' in data:
|
||||||
@@ -199,12 +206,22 @@ class FlaskAPIDataSource:
|
|||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
if attempt < self.retries - 1:
|
if attempt < self.retries - 1:
|
||||||
print(f"⚠ {code}: 请求超时,重试 {attempt + 2}/{self.retries}")
|
print(f"⚠ {code}: 请求超时,重试 {attempt + 2}/{self.retries}")
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f"✗ {code}: 请求超时")
|
print(f"✗ {code}: 请求超时")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
print(f"⚠ {code}: {type(e).__name__},重试 {attempt + 2}/{self.retries}")
|
||||||
|
time.sleep(1 + attempt)
|
||||||
|
continue
|
||||||
|
print(f"✗ {code}: {type(e).__name__} after {self.retries} retries")
|
||||||
|
return None
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
if attempt < self.retries - 1:
|
if attempt < self.retries - 1:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f"✗ {code}: 请求异常 - {e}")
|
print(f"✗ {code}: 请求异常 - {e}")
|
||||||
return None
|
return None
|
||||||
@@ -277,16 +294,12 @@ class FlaskAPIDataSource:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, params=params, timeout=self.timeout)
|
response = _http_get(url, params=params, timeout=self.timeout)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 处理 zstd 响应
|
|
||||||
try:
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except (json.JSONDecodeError, requests.exceptions.JSONDecodeError):
|
|
||||||
data = json.loads(response.text)
|
|
||||||
|
|
||||||
if 'error' in data:
|
if 'error' in data:
|
||||||
return None
|
return None
|
||||||
@@ -357,7 +370,7 @@ class FlaskAPIDataSource:
|
|||||||
params = {'code': '000300.SH', 'start': '2024-01-01', 'end': '2024-01-05'}
|
params = {'code': '000300.SH', 'start': '2024-01-01', 'end': '2024-01-05'}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, params=params, timeout=self.timeout)
|
response = _http_get(url, params=params, timeout=self.timeout)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return {
|
return {
|
||||||
@@ -375,7 +388,7 @@ class FlaskAPIDataSource:
|
|||||||
url = f"{self.base_url}/api/v1/calendar/info"
|
url = f"{self.base_url}/api/v1/calendar/info"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, timeout=10)
|
response = _http_get(url, timeout=10)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.json()
|
return response.json()
|
||||||
else:
|
else:
|
||||||
@@ -420,15 +433,12 @@ class FlaskAPIDataSource:
|
|||||||
|
|
||||||
for attempt in range(self.retries):
|
for attempt in range(self.retries):
|
||||||
try:
|
try:
|
||||||
response = requests.get(
|
response = _http_get(url, params=params, timeout=self.timeout)
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
timeout=self.timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
if attempt < self.retries - 1:
|
if attempt < self.retries - 1:
|
||||||
print(f"⚠ 交易日历请求失败 (HTTP {response.status_code}),重试 {attempt + 2}/{self.retries}")
|
print(f"⚠ 交易日历请求失败 (HTTP {response.status_code}),重试 {attempt + 2}/{self.retries}")
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f"✗ 交易日历请求失败: HTTP {response.status_code} - {response.text[:100]}")
|
print(f"✗ 交易日历请求失败: HTTP {response.status_code} - {response.text[:100]}")
|
||||||
return None
|
return None
|
||||||
@@ -457,17 +467,27 @@ class FlaskAPIDataSource:
|
|||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
if attempt < self.retries - 1:
|
if attempt < self.retries - 1:
|
||||||
print(f"⚠ 交易日历请求超时,重试 {attempt + 2}/{self.retries}")
|
print(f"⚠ 交易日历请求超时,重试 {attempt + 2}/{self.retries}")
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f"✗ 交易日历请求超时")
|
print(f"✗ 交易日历请求超时")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
print(f"⚠ 交易日历: {type(e).__name__},重试 {attempt + 2}/{self.retries}")
|
||||||
|
time.sleep(1 + attempt)
|
||||||
|
continue
|
||||||
|
print(f"✗ 交易日历: {type(e).__name__} after {self.retries} retries")
|
||||||
|
return None
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
if attempt < self.retries - 1:
|
if attempt < self.retries - 1:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f"✗ 交易日历请求异常: {e}")
|
print(f"✗ 交易日历请求异常: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except (json.JSONDecodeError, requests.exceptions.JSONDecodeError) as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"✗ 交易日历 JSON 解析失败: {e}")
|
print(f"✗ 交易日历 JSON 解析失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -478,7 +498,7 @@ class FlaskAPIDataSource:
|
|||||||
url = f"{self.base_url}/"
|
url = f"{self.base_url}/"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, timeout=10)
|
response = _http_get(url, timeout=10)
|
||||||
return response.json()
|
return response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ rotation:
|
|||||||
rebalance:
|
rebalance:
|
||||||
min_hold_days: 1
|
min_hold_days: 1
|
||||||
score_threshold: 0.0
|
score_threshold: 0.0
|
||||||
trade_cost: 0.001 # 0.1% 交易成本
|
trade_cost: 0.001 # 万1 交易成本(场内ETF万0.5免5)
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# 溢价控制配置
|
# 溢价控制配置
|
||||||
|
|||||||
@@ -26,6 +26,19 @@ sys.path.insert(0, str(PROJECT_ROOT))
|
|||||||
|
|
||||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
|
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Clash 等代理在处理 TLS 1.3 + 后量子密钥交换时会触发 SSL EOF 错误
|
||||||
|
# trust_env=False 让 requests 忽略环境变量中的代理配置,直连目标服务器
|
||||||
|
_session = requests.Session()
|
||||||
|
_session.trust_env = False
|
||||||
|
|
||||||
|
def _http_get(url: str, params: dict = None, timeout: int = 120) -> requests.Response:
|
||||||
|
"""使用 requests 发起 GET 请求(trust_env=False 绕过系统代理)"""
|
||||||
|
return _session.get(url, params=params, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Pure functions: momentum
|
# Pure functions: momentum
|
||||||
@@ -68,53 +81,18 @@ def is_crash(prices: np.ndarray) -> bool:
|
|||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
class DataCache:
|
class DataCache:
|
||||||
"""CSV-cached data fetching for index (raw) and ETF (hfq)"""
|
"""Data fetching (no local cache, always from API)"""
|
||||||
|
|
||||||
def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
|
def __init__(self, base_url: str, timeout: int = 120):
|
||||||
self.base_url = base_url.rstrip('/')
|
self.base_url = base_url.rstrip('/')
|
||||||
self.api_path = '/api/v1/ohlcv'
|
self.api_path = '/api/v1/ohlcv'
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
if cache_dir is None:
|
# premium data in memory: {trade_code: {date_str: premium_ratio}}
|
||||||
cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
|
|
||||||
self.cache_dir = Path(cache_dir)
|
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
# premium data cache: {trade_code: {date_str: premium_ratio}}
|
|
||||||
self.premium_data: Dict[str, Dict[str, float]] = {}
|
self.premium_data: Dict[str, Dict[str, float]] = {}
|
||||||
|
|
||||||
def _cache_path(self, code: str, adj: str) -> Path:
|
|
||||||
prefix = 'index' if adj == 'raw' else 'etf'
|
|
||||||
safe_code = code.replace('=', '_').replace('^', '_')
|
|
||||||
return self.cache_dir / f"{prefix}_{safe_code}.csv"
|
|
||||||
|
|
||||||
def _premium_cache_path(self, code: str) -> Path:
|
|
||||||
safe_code = code.replace('=', '_').replace('^', '_')
|
|
||||||
return self.cache_dir / f"premium_{safe_code}.csv"
|
|
||||||
|
|
||||||
def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
|
def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
|
||||||
"""Preload full history and cache to CSV"""
|
"""Fetch data from API"""
|
||||||
cache_path = self._cache_path(code, adj)
|
return self._fetch_api(code, start_date, end_date, adj)
|
||||||
if cache_path.exists():
|
|
||||||
try:
|
|
||||||
df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
|
|
||||||
if len(df) > 0:
|
|
||||||
cs = df.index.min().strftime('%Y-%m-%d')
|
|
||||||
ce = df.index.max().strftime('%Y-%m-%d')
|
|
||||||
if cs <= start_date and ce >= end_date:
|
|
||||||
return df
|
|
||||||
new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
|
|
||||||
if new_start <= end_date:
|
|
||||||
new_df = self._fetch_api(code, new_start, end_date, adj)
|
|
||||||
if new_df is not None and len(new_df) > 0:
|
|
||||||
df = pd.concat([df, new_df])
|
|
||||||
df = df[~df.index.duplicated(keep='last')]
|
|
||||||
df.to_csv(cache_path)
|
|
||||||
return df
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
df = self._fetch_api(code, start_date, end_date, adj)
|
|
||||||
if df is not None and len(df) > 0:
|
|
||||||
df.to_csv(cache_path)
|
|
||||||
return df
|
|
||||||
|
|
||||||
def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
|
def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
|
||||||
"""Fetch from Flask API, also extracts premium_series for ETFs"""
|
"""Fetch from Flask API, also extracts premium_series for ETFs"""
|
||||||
@@ -122,7 +100,7 @@ class DataCache:
|
|||||||
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
|
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
try:
|
try:
|
||||||
resp = requests.get(url, params=params, timeout=self.timeout)
|
resp = _http_get(url, params=params, timeout=self.timeout)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
@@ -142,75 +120,49 @@ class DataCache:
|
|||||||
df = df.set_index('date').sort_index()
|
df = df.set_index('date').sort_index()
|
||||||
keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
|
keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
|
||||||
df = df[keep]
|
df = df[keep]
|
||||||
# Extract and cache premium_series (ETF only)
|
# Extract premium_series (ETF only) and store in memory
|
||||||
premium_series = data.get('premium_series', [])
|
premium_series = data.get('premium_series', [])
|
||||||
if premium_series:
|
if premium_series:
|
||||||
df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
|
df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
|
||||||
self._save_premium_cache(code, df.attrs['premium_series'])
|
if code not in self.premium_data:
|
||||||
|
self.premium_data[code] = {}
|
||||||
|
self.premium_data[code].update(df.attrs['premium_series'])
|
||||||
print(f" + {code}: {len(df)} rows ({adj})")
|
print(f" + {code}: {len(df)} rows ({adj})")
|
||||||
return df
|
return df
|
||||||
except requests.exceptions.Timeout:
|
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
|
||||||
|
# 网络相关错误(超时、SSL、连接断开)都进行重试
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
|
time.sleep(1 + attempt) # 递增延迟: 1s, 2s
|
||||||
continue
|
continue
|
||||||
print(f" x {code}: timeout")
|
print(f" x {code}: {type(e).__name__} after {attempt+1} retries")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" x {code}: {e}")
|
print(f" x {code}: {e}")
|
||||||
return None
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]):
|
|
||||||
"""Save premium data to CSV cache"""
|
|
||||||
try:
|
|
||||||
cache_path = self._premium_cache_path(code)
|
|
||||||
pd.DataFrame(
|
|
||||||
[{'date': k, 'premium': v} for k, v in premium_dict.items()]
|
|
||||||
).to_csv(cache_path, index=False)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
|
def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
|
||||||
"""Load premium data for an ETF code from cache, with incremental update.
|
"""Load premium data for an ETF code, fetch from API if not in memory."""
|
||||||
If cache exists but doesn't cover end_date, fetches the gap."""
|
|
||||||
if code in self.premium_data:
|
if code in self.premium_data:
|
||||||
# Already in memory - check if up-to-date
|
|
||||||
if end_date:
|
if end_date:
|
||||||
dates = sorted(self.premium_data[code].keys())
|
dates = sorted(self.premium_data[code].keys())
|
||||||
if dates and dates[-1] >= end_date:
|
if dates and dates[-1] >= end_date:
|
||||||
return self.premium_data[code]
|
return self.premium_data[code]
|
||||||
else:
|
else:
|
||||||
return self.premium_data[code]
|
return self.premium_data[code]
|
||||||
cache_path = self._premium_cache_path(code)
|
|
||||||
if cache_path.exists():
|
|
||||||
try:
|
|
||||||
df = pd.read_csv(cache_path)
|
|
||||||
if len(df) > 0 and 'date' in df.columns and 'premium' in df.columns:
|
|
||||||
self.premium_data[code] = dict(zip(df['date'].astype(str), df['premium']))
|
|
||||||
# Check if cache covers end_date
|
|
||||||
if end_date:
|
|
||||||
latest_cached = max(self.premium_data[code].keys())
|
|
||||||
if latest_cached >= end_date:
|
|
||||||
return self.premium_data[code]
|
|
||||||
# Cache is stale - fetch gap from latest_cached+1 to end_date
|
|
||||||
fetch_start = (pd.Timestamp(latest_cached) + timedelta(days=1)).strftime('%Y-%m-%d')
|
|
||||||
self._fetch_premium_api(code, fetch_start, end_date)
|
|
||||||
return self.premium_data[code]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# No cache: fetch full history from API
|
|
||||||
self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
|
self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
|
||||||
return self.premium_data.get(code)
|
return self.premium_data.get(code)
|
||||||
|
|
||||||
def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
|
def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
|
||||||
"""Fetch premium_series from API and merge into cache"""
|
"""Fetch premium_series from API and store in memory"""
|
||||||
url = f"{self.base_url}{self.api_path}"
|
url = f"{self.base_url}{self.api_path}"
|
||||||
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
|
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
try:
|
try:
|
||||||
resp = requests.get(url, params=params, timeout=self.timeout)
|
resp = _http_get(url, params=params, timeout=self.timeout)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
time.sleep(1)
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
return
|
return
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
@@ -222,11 +174,11 @@ class DataCache:
|
|||||||
if code not in self.premium_data:
|
if code not in self.premium_data:
|
||||||
self.premium_data[code] = {}
|
self.premium_data[code] = {}
|
||||||
self.premium_data[code].update(new_data)
|
self.premium_data[code].update(new_data)
|
||||||
self._save_premium_cache(code, self.premium_data[code])
|
|
||||||
print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
|
print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
|
||||||
return
|
return
|
||||||
except requests.exceptions.Timeout:
|
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError):
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -238,9 +190,10 @@ class DataCache:
|
|||||||
params = {'market': market, 'start': start_date, 'end': end_date}
|
params = {'market': market, 'start': start_date, 'end': end_date}
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
try:
|
try:
|
||||||
resp = requests.get(url, params=params, timeout=self.timeout)
|
resp = _http_get(url, params=params, timeout=self.timeout)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
@@ -253,8 +206,15 @@ class DataCache:
|
|||||||
result = pd.DatetimeIndex(dates)
|
result = pd.DatetimeIndex(dates)
|
||||||
print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
|
print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
|
||||||
return result
|
return result
|
||||||
|
except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
|
||||||
|
if attempt < 2:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
|
continue
|
||||||
|
print(f" x calendar: {type(e).__name__} after {attempt+1} retries")
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
|
time.sleep(1 + attempt)
|
||||||
continue
|
continue
|
||||||
print(f" x calendar: {e}")
|
print(f" x calendar: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -450,14 +410,17 @@ class SimpleRotationStrategy:
|
|||||||
today_date = recent.index[-1]
|
today_date = recent.index[-1]
|
||||||
if (date - today_date).days > 5:
|
if (date - today_date).days > 5:
|
||||||
return None
|
return None
|
||||||
close = float(today_row['close'])
|
close_raw = today_row.get('close')
|
||||||
prev_close = float(prev_row['close'])
|
prev_close_raw = prev_row.get('close')
|
||||||
# Handle missing open (common for index data like 931862.CSI)
|
if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw):
|
||||||
open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close
|
return None
|
||||||
|
close = float(close_raw)
|
||||||
|
prev_close = float(prev_close_raw)
|
||||||
|
# Handle missing/invalid open (common for index data like 931862.CSI)
|
||||||
|
raw_open = today_row.get('open')
|
||||||
|
open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close
|
||||||
if pd.isna(open_price) or open_price == 0:
|
if pd.isna(open_price) or open_price == 0:
|
||||||
open_price = close
|
open_price = close
|
||||||
if pd.isna(close) or pd.isna(prev_close):
|
|
||||||
return None
|
|
||||||
return {
|
return {
|
||||||
'open': open_price,
|
'open': open_price,
|
||||||
'close': close,
|
'close': close,
|
||||||
@@ -971,7 +934,7 @@ class SimpleRotationStrategy:
|
|||||||
if self.benchmark_data is not None:
|
if self.benchmark_data is not None:
|
||||||
bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
|
bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
|
||||||
if bm_close is not None and not bm_close.isna().all():
|
if bm_close is not None and not bm_close.isna().all():
|
||||||
benchmark_nav = (1 + bm_close.pct_change()).cumprod()
|
benchmark_nav = (1 + bm_close.pct_change(fill_method=None)).cumprod()
|
||||||
first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1
|
first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1
|
||||||
benchmark_nav = benchmark_nav / first_valid
|
benchmark_nav = benchmark_nav / first_valid
|
||||||
|
|
||||||
@@ -981,7 +944,7 @@ class SimpleRotationStrategy:
|
|||||||
if code in self.index_data:
|
if code in self.index_data:
|
||||||
price = self.index_data[code]['close'].reindex(df.index, method='ffill')
|
price = self.index_data[code]['close'].reindex(df.index, method='ffill')
|
||||||
if price is not None and not price.isna().all():
|
if price is not None and not price.isna().all():
|
||||||
nav_s = (1 + price.pct_change()).cumprod()
|
nav_s = (1 + price.pct_change(fill_method=None)).cumprod()
|
||||||
fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1
|
fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1
|
||||||
asset_navs[code] = nav_s / fv
|
asset_navs[code] = nav_s / fv
|
||||||
|
|
||||||
@@ -999,7 +962,7 @@ class SimpleRotationStrategy:
|
|||||||
# Benchmark metrics
|
# Benchmark metrics
|
||||||
b_total_return = b_annual = b_sharpe = b_dd = 0
|
b_total_return = b_annual = b_sharpe = b_dd = 0
|
||||||
if benchmark_nav is not None:
|
if benchmark_nav is not None:
|
||||||
bm_ret = benchmark_nav.pct_change()
|
bm_ret = benchmark_nav.pct_change(fill_method=None)
|
||||||
b_total_return = benchmark_nav.iloc[-1] - 1
|
b_total_return = benchmark_nav.iloc[-1] - 1
|
||||||
b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
||||||
b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0
|
b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0
|
||||||
|
|||||||
Reference in New Issue
Block a user