diff --git a/datasource/flask_api_source.py b/datasource/flask_api_source.py index 2ca69d0..393666f 100644 --- a/datasource/flask_api_source.py +++ b/datasource/flask_api_source.py @@ -370,6 +370,109 @@ class FlaskAPIDataSource: except Exception as e: return {'status': 'error', 'message': str(e), 'available': False} + def get_calendar_info(self) -> Dict: + """获取交易日历服务信息""" + url = f"{self.base_url}/api/v1/calendar/info" + + try: + response = requests.get(url, timeout=10) + if response.status_code == 200: + return response.json() + else: + return {"error": f"HTTP {response.status_code}"} + except Exception as e: + return {"error": str(e)} + + def get_trading_calendar( + self, + market: str, + start_date: str, + end_date: str + ) -> Optional[pd.DatetimeIndex]: + """ + 获取交易日历 + + Args: + market: 市场代码 + - 'A' 或 'china': A股(上交所/深交所,交易日历一致) + - 'US' 或 'us': 美股(NYSE) + - 'HK' 或 'hk': 港股(HKEX) + start_date: 开始日期 YYYY-MM-DD + end_date: 结束日期 YYYY-MM-DD + + Returns: + DatetimeIndex: 交易日日期序列,失败返回 None + + 示例: + # 获取 A 股 2024 年 1 月交易日历 + dates = source.get_trading_calendar('A', '2024-01-01', '2024-01-31') + + # 获取美股交易日历 + dates = source.get_trading_calendar('US', '2024-01-01', '2024-01-15') + """ + url = f"{self.base_url}/api/v1/trading-calendar" + + params = { + 'market': market, + 'start': start_date, + 'end': end_date + } + + for attempt in range(self.retries): + try: + response = requests.get( + url, + params=params, + timeout=self.timeout + ) + + if response.status_code != 200: + if attempt < self.retries - 1: + print(f"⚠ 交易日历请求失败 (HTTP {response.status_code}),重试 {attempt + 2}/{self.retries}") + continue + print(f"✗ 交易日历请求失败: HTTP {response.status_code} - {response.text[:100]}") + return None + + data = response.json() + + # 检查错误 + if 'error' in data: + print(f"✗ 交易日历获取失败: {data['error']}") + return None + + # 解析交易日期 + trading_dates = data.get('trading_dates', []) + if not trading_dates: + print(f"⚠ 市场 {market} 在 {start_date} ~ {end_date} 期间无交易日") + return pd.DatetimeIndex([]) + + # 转换为 DatetimeIndex + dates = pd.DatetimeIndex(trading_dates) + count = data.get('count', len(dates)) + exchange = data.get('exchange', '') + + print(f"✓ {market} ({exchange}): {count} 个交易日 ({start_date} ~ {end_date})") + return dates + + except requests.exceptions.Timeout: + if attempt < self.retries - 1: + print(f"⚠ 交易日历请求超时,重试 {attempt + 2}/{self.retries}") + continue + print(f"✗ 交易日历请求超时") + return None + + except requests.exceptions.RequestException as e: + if attempt < self.retries - 1: + continue + print(f"✗ 交易日历请求异常: {e}") + return None + + except (json.JSONDecodeError, requests.exceptions.JSONDecodeError) as e: + print(f"✗ 交易日历 JSON 解析失败: {e}") + return None + + return None + def get_service_info(self) -> Dict: """获取服务信息""" url = f"{self.base_url}/" diff --git a/tests/test_flask_api_calendar.py b/tests/test_flask_api_calendar.py new file mode 100644 index 0000000..a0e641b --- /dev/null +++ b/tests/test_flask_api_calendar.py @@ -0,0 +1,95 @@ +""" +测试 FlaskAPIDataSource 的交易日历获取功能 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from datasource.flask_api_source import get_flask_api_source + + +def test_trading_calendar(): + """测试交易日历获取""" + print("=" * 60) + print("测试 FlaskAPIDataSource 交易日历获取") + print("=" * 60) + + source = get_flask_api_source() + + # 测试 1: 获取服务信息 + print("\n[测试 1] 获取服务信息") + info = source.get_service_info() + if 'error' not in info: + print(f"✓ 服务名称: {info.get('name', 'N/A')}") + print(f"✓ API 版本: {info.get('version', 'N/A')}") + else: + print(f"✗ 服务信息获取失败: {info['error']}") + + # 测试 2: 获取日历信息 + print("\n[测试 2] 获取日历信息") + cal_info = source.get_calendar_info() + if 'error' not in cal_info: + print(f"✓ pandas_market_calendars 已安装: {cal_info.get('pandas_market_calendars_installed')}") + print(f"✓ 支持的市场: {list(cal_info.get('supported_markets', {}).keys())}") + else: + print(f"✗ 日历信息获取失败: {cal_info['error']}") + + # 测试 3: A 股交易日历 + print("\n[测试 3] A 股交易日历 (2024-01)") + dates_a = source.get_trading_calendar('A', '2024-01-01', '2024-01-31') + if dates_a is not None: + print(f"✓ 返回 {len(dates_a)} 个交易日") + print(f" 首个交易日: {dates_a[0].strftime('%Y-%m-%d')}") + print(f" 最后交易日: {dates_a[-1].strftime('%Y-%m-%d')}") + else: + print("✗ A 股交易日历获取失败") + + # 测试 4: 美股交易日历 + print("\n[测试 4] 美股交易日历 (2024-01)") + dates_us = source.get_trading_calendar('US', '2024-01-01', '2024-01-15') + if dates_us is not None: + print(f"✓ 返回 {len(dates_us)} 个交易日") + print(f" 首个交易日: {dates_us[0].strftime('%Y-%m-%d')}") + print(f" 最后交易日: {dates_us[-1].strftime('%Y-%m-%d')}") + # 验证马丁·路德·金日(1月15日)是否被排除 + mlk_day = '2024-01-15' + if mlk_day not in [d.strftime('%Y-%m-%d') for d in dates_us]: + print(f" ✓ 正确识别马丁·路德·金日休市") + else: + print("✗ 美股交易日历获取失败") + + # 测试 5: 港股交易日历 + print("\n[测试 5] 港股交易日历 (2024-01)") + dates_hk = source.get_trading_calendar('HK', '2024-01-01', '2024-01-15') + if dates_hk is not None: + print(f"✓ 返回 {len(dates_hk)} 个交易日") + print(f" 首个交易日: {dates_hk[0].strftime('%Y-%m-%d')}") + print(f" 最后交易日: {dates_hk[-1].strftime('%Y-%m-%d')}") + else: + print("✗ 港股交易日历获取失败") + + # 测试 6: 验证 OHLCV 功能仍然正常 + print("\n[测试 6] 验证 OHLCV 数据获取") + df = source.fetch('518880.SH', '2024-01-01', '2024-01-10') + if df is not None and len(df) > 0: + print(f"✓ 获取 {len(df)} 条 OHLCV 数据") + else: + print("✗ OHLCV 数据获取失败") + + # 总结 + print("\n" + "=" * 60) + success_count = sum([ + dates_a is not None, + dates_us is not None, + dates_hk is not None, + df is not None + ]) + print(f"测试完成: {success_count}/4 通过") + print("=" * 60) + + +if __name__ == '__main__': + test_trading_calendar()