feat: 为 FlaskAPIDataSource 添加交易日历获取功能
- 新增 get_trading_calendar() 方法,支持 A股/美股/港股 - 新增 get_calendar_info() 方法,获取服务信息 - 支持自动重试、超时保护、详细错误提示 - 返回标准 DatetimeIndex 格式 - 添加端到端测试验证所有市场
This commit is contained in:
@@ -370,6 +370,109 @@ class FlaskAPIDataSource:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {'status': 'error', 'message': str(e), 'available': False}
|
return {'status': 'error', 'message': str(e), 'available': False}
|
||||||
|
|
||||||
|
def get_calendar_info(self) -> Dict:
|
||||||
|
"""获取交易日历服务信息"""
|
||||||
|
url = f"{self.base_url}/api/v1/calendar/info"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
return {"error": f"HTTP {response.status_code}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
def get_trading_calendar(
|
||||||
|
self,
|
||||||
|
market: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str
|
||||||
|
) -> Optional[pd.DatetimeIndex]:
|
||||||
|
"""
|
||||||
|
获取交易日历
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market: 市场代码
|
||||||
|
- 'A' 或 'china': A股(上交所/深交所,交易日历一致)
|
||||||
|
- 'US' 或 'us': 美股(NYSE)
|
||||||
|
- 'HK' 或 'hk': 港股(HKEX)
|
||||||
|
start_date: 开始日期 YYYY-MM-DD
|
||||||
|
end_date: 结束日期 YYYY-MM-DD
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatetimeIndex: 交易日日期序列,失败返回 None
|
||||||
|
|
||||||
|
示例:
|
||||||
|
# 获取 A 股 2024 年 1 月交易日历
|
||||||
|
dates = source.get_trading_calendar('A', '2024-01-01', '2024-01-31')
|
||||||
|
|
||||||
|
# 获取美股交易日历
|
||||||
|
dates = source.get_trading_calendar('US', '2024-01-01', '2024-01-15')
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/api/v1/trading-calendar"
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'market': market,
|
||||||
|
'start': start_date,
|
||||||
|
'end': end_date
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempt in range(self.retries):
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
print(f"⚠ 交易日历请求失败 (HTTP {response.status_code}),重试 {attempt + 2}/{self.retries}")
|
||||||
|
continue
|
||||||
|
print(f"✗ 交易日历请求失败: HTTP {response.status_code} - {response.text[:100]}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 检查错误
|
||||||
|
if 'error' in data:
|
||||||
|
print(f"✗ 交易日历获取失败: {data['error']}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 解析交易日期
|
||||||
|
trading_dates = data.get('trading_dates', [])
|
||||||
|
if not trading_dates:
|
||||||
|
print(f"⚠ 市场 {market} 在 {start_date} ~ {end_date} 期间无交易日")
|
||||||
|
return pd.DatetimeIndex([])
|
||||||
|
|
||||||
|
# 转换为 DatetimeIndex
|
||||||
|
dates = pd.DatetimeIndex(trading_dates)
|
||||||
|
count = data.get('count', len(dates))
|
||||||
|
exchange = data.get('exchange', '')
|
||||||
|
|
||||||
|
print(f"✓ {market} ({exchange}): {count} 个交易日 ({start_date} ~ {end_date})")
|
||||||
|
return dates
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
print(f"⚠ 交易日历请求超时,重试 {attempt + 2}/{self.retries}")
|
||||||
|
continue
|
||||||
|
print(f"✗ 交易日历请求超时")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
continue
|
||||||
|
print(f"✗ 交易日历请求异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, requests.exceptions.JSONDecodeError) as e:
|
||||||
|
print(f"✗ 交易日历 JSON 解析失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_service_info(self) -> Dict:
|
def get_service_info(self) -> Dict:
|
||||||
"""获取服务信息"""
|
"""获取服务信息"""
|
||||||
url = f"{self.base_url}/"
|
url = f"{self.base_url}/"
|
||||||
|
|||||||
95
tests/test_flask_api_calendar.py
Normal file
95
tests/test_flask_api_calendar.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""
|
||||||
|
测试 FlaskAPIDataSource 的交易日历获取功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from datasource.flask_api_source import get_flask_api_source
|
||||||
|
|
||||||
|
|
||||||
|
def test_trading_calendar():
|
||||||
|
"""测试交易日历获取"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("测试 FlaskAPIDataSource 交易日历获取")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
source = get_flask_api_source()
|
||||||
|
|
||||||
|
# 测试 1: 获取服务信息
|
||||||
|
print("\n[测试 1] 获取服务信息")
|
||||||
|
info = source.get_service_info()
|
||||||
|
if 'error' not in info:
|
||||||
|
print(f"✓ 服务名称: {info.get('name', 'N/A')}")
|
||||||
|
print(f"✓ API 版本: {info.get('version', 'N/A')}")
|
||||||
|
else:
|
||||||
|
print(f"✗ 服务信息获取失败: {info['error']}")
|
||||||
|
|
||||||
|
# 测试 2: 获取日历信息
|
||||||
|
print("\n[测试 2] 获取日历信息")
|
||||||
|
cal_info = source.get_calendar_info()
|
||||||
|
if 'error' not in cal_info:
|
||||||
|
print(f"✓ pandas_market_calendars 已安装: {cal_info.get('pandas_market_calendars_installed')}")
|
||||||
|
print(f"✓ 支持的市场: {list(cal_info.get('supported_markets', {}).keys())}")
|
||||||
|
else:
|
||||||
|
print(f"✗ 日历信息获取失败: {cal_info['error']}")
|
||||||
|
|
||||||
|
# 测试 3: A 股交易日历
|
||||||
|
print("\n[测试 3] A 股交易日历 (2024-01)")
|
||||||
|
dates_a = source.get_trading_calendar('A', '2024-01-01', '2024-01-31')
|
||||||
|
if dates_a is not None:
|
||||||
|
print(f"✓ 返回 {len(dates_a)} 个交易日")
|
||||||
|
print(f" 首个交易日: {dates_a[0].strftime('%Y-%m-%d')}")
|
||||||
|
print(f" 最后交易日: {dates_a[-1].strftime('%Y-%m-%d')}")
|
||||||
|
else:
|
||||||
|
print("✗ A 股交易日历获取失败")
|
||||||
|
|
||||||
|
# 测试 4: 美股交易日历
|
||||||
|
print("\n[测试 4] 美股交易日历 (2024-01)")
|
||||||
|
dates_us = source.get_trading_calendar('US', '2024-01-01', '2024-01-15')
|
||||||
|
if dates_us is not None:
|
||||||
|
print(f"✓ 返回 {len(dates_us)} 个交易日")
|
||||||
|
print(f" 首个交易日: {dates_us[0].strftime('%Y-%m-%d')}")
|
||||||
|
print(f" 最后交易日: {dates_us[-1].strftime('%Y-%m-%d')}")
|
||||||
|
# 验证马丁·路德·金日(1月15日)是否被排除
|
||||||
|
mlk_day = '2024-01-15'
|
||||||
|
if mlk_day not in [d.strftime('%Y-%m-%d') for d in dates_us]:
|
||||||
|
print(f" ✓ 正确识别马丁·路德·金日休市")
|
||||||
|
else:
|
||||||
|
print("✗ 美股交易日历获取失败")
|
||||||
|
|
||||||
|
# 测试 5: 港股交易日历
|
||||||
|
print("\n[测试 5] 港股交易日历 (2024-01)")
|
||||||
|
dates_hk = source.get_trading_calendar('HK', '2024-01-01', '2024-01-15')
|
||||||
|
if dates_hk is not None:
|
||||||
|
print(f"✓ 返回 {len(dates_hk)} 个交易日")
|
||||||
|
print(f" 首个交易日: {dates_hk[0].strftime('%Y-%m-%d')}")
|
||||||
|
print(f" 最后交易日: {dates_hk[-1].strftime('%Y-%m-%d')}")
|
||||||
|
else:
|
||||||
|
print("✗ 港股交易日历获取失败")
|
||||||
|
|
||||||
|
# 测试 6: 验证 OHLCV 功能仍然正常
|
||||||
|
print("\n[测试 6] 验证 OHLCV 数据获取")
|
||||||
|
df = source.fetch('518880.SH', '2024-01-01', '2024-01-10')
|
||||||
|
if df is not None and len(df) > 0:
|
||||||
|
print(f"✓ 获取 {len(df)} 条 OHLCV 数据")
|
||||||
|
else:
|
||||||
|
print("✗ OHLCV 数据获取失败")
|
||||||
|
|
||||||
|
# 总结
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
success_count = sum([
|
||||||
|
dates_a is not None,
|
||||||
|
dates_us is not None,
|
||||||
|
dates_hk is not None,
|
||||||
|
df is not None
|
||||||
|
])
|
||||||
|
print(f"测试完成: {success_count}/4 通过")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_trading_calendar()
|
||||||
Reference in New Issue
Block a user