Files
etf/scripts/get_trading_calendar.py
aszerW 451ffa33d2 clean(rotation): add simple rotation strategy and remove unused files
New:
- rotation/simple_rotation.py: daily-iteration rotation strategy (584 lines)
- rotation/config_loader.py: standalone config loader
- rotation/config_simple.yaml: 11 assets, 7 groups
- rotation/README_SIMPLE.md: usage guide
- scripts/get_trading_calendar.py: trading calendar fetcher

Removed:
- rotation/example_usage.py, run_strategy.py (replaced by simple_rotation.py)
- rotation/results/ output files (gitignored)
- scripts/verify_*.py, calculate_returns_from_detail.py (one-off scripts)
- scripts/README_TRADING_CALENDAR.md

Backtest result (2020-01-10 ~ 2026-06-01):
- Total return: 1237.6%, Annual: 52.66%
- Max drawdown: -11.71%, Sharpe: 2.50
2026-06-01 22:28:26 +08:00

348 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
获取 A 股交易日历脚本
使用 Flask API 交易日历服务获取 A 股交易日历
支持多市场、多年份的交易日查询
用法:
python scripts/get_trading_calendar.py
python scripts/get_trading_calendar.py --year 2024
python scripts/get_trading_calendar.py --start 2024-01-01 --end 2024-12-31
"""
import sys
import argparse
from pathlib import Path
from datetime import datetime, timedelta
import pandas as pd
# Put the directory two levels above this file (the parent of `scripts/`) at the
# front of sys.path, so the project-local `datasource` package imported below
# resolves when this script is run directly.
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
# Load environment variables from a .env file (python-dotenv) before the data
# source is created. NOTE(review): presumably FlaskAPIDataSource reads its
# connection settings from the environment — confirm.
from dotenv import load_dotenv
load_dotenv()
# Import the Flask API data source; this must come after the sys.path tweak above.
from datasource.flask_api_source import FlaskAPIDataSource
def get_calendar_for_year(source: FlaskAPIDataSource, year: int, market: str = 'A'):
    """
    Fetch one calendar year of trading days for a market.

    Args:
        source: Flask API data source used for the request.
        year: Calendar year, e.g. 2024; the request spans Jan 1 to Dec 31.
        market: Market code ('A', 'US' or 'HK').

    Returns:
        The trading days returned by ``source.get_trading_calendar``
        (documented as a pd.DatetimeIndex), or None when the source returns
        nothing or an empty result; a notice is printed in that case.
    """
    print(f"\n获取 {year}{market} 市场交易日历...")
    dates = source.get_trading_calendar(
        market=market,
        start_date=f"{year}-01-01",
        end_date=f"{year}-12-31",
    )
    has_data = dates is not None and len(dates) > 0
    if has_data:
        return dates
    print(f"{year}{market} 市场无交易日数据")
    return None
def analyze_calendar(trading_dates: pd.DatetimeIndex, year: int):
    """
    Print summary statistics for a trading calendar.

    The report contains the total count and date range, per-month and
    per-quarter counts, the first/last trading day with weekday names, and up
    to five gaps of more than three calendar days between consecutive trading
    days (a rough holiday indicator). Nothing is printed for None/empty input.

    Args:
        trading_dates: Trading days, assumed to be in ascending order (the gap
            scan compares neighbours in the given order).
        year: Year shown in the report title.
    """
    if trading_dates is None or len(trading_dates) == 0:
        return

    banner = '=' * 60
    dates = list(trading_dates)
    first_day = trading_dates.min()
    last_day = trading_dates.max()

    print(f"\n{banner}")
    print(f"{year} 年 A 股交易日历分析")
    print(banner)

    # Overall size and span.
    print("\n基本统计:")
    print(f" 总交易日: {len(trading_dates)}")
    print(f" 起始日期: {first_day.strftime('%Y-%m-%d')}")
    print(f" 结束日期: {last_day.strftime('%Y-%m-%d')}")

    # Per-month counts; months with no trading days are shown as 0.
    month_of = [d.month for d in dates]
    print("\n按月份统计:")
    for m in range(1, 13):
        label = datetime(2024, m, 1).strftime('%B')
        print(f" {m:02d}月 ({label}): {month_of.count(m)}")

    # Per-quarter counts derived from the month numbers.
    quarter_of = [(mm - 1) // 3 + 1 for mm in month_of]
    print("\n按季度统计:")
    for q in (1, 2, 3, 4):
        print(f" Q{q}: {quarter_of.count(q)}")

    print("\n特殊日期:")
    print(f" 首个交易日: {first_day.strftime('%Y-%m-%d')} ({first_day.strftime('%A')})")
    print(f" 最后交易日: {last_day.strftime('%Y-%m-%d')} ({last_day.strftime('%A')})")

    # A gap longer than 3 calendar days (more than a normal Fri -> Mon weekend)
    # is treated as a likely holiday.
    long_gaps = [
        (before, after, (after - before).days)
        for before, after in zip(dates, dates[1:])
        if (after - before).days > 3
    ]
    if long_gaps:
        print(f"\n可能的节假日(间隔 > 3天:")
        for before, after, span in long_gaps[:5]:  # only the first five
            print(f" {before.strftime('%Y-%m-%d')}{after.strftime('%Y-%m-%d')} "
                  f"(间隔 {span} 天)")

    print(f"\n{banner}")
def compare_markets(source: FlaskAPIDataSource, year: int):
    """
    Fetch one year's trading calendars for the A-share, US and HK markets and compare them.

    Prints a table with each market's trading-day count and first/last date,
    then the pairwise count differences. Markets whose calendar comes back
    missing or empty are left out of both sections.

    Args:
        source: Flask API data source used to fetch each calendar.
        year: Calendar year to compare.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"{year} 年不同市场交易日历对比")
    print(banner)

    market_names = {
        'A': 'A股上交所/深交所)',
        'US': '美股NYSE',
        'HK': '港股HKEX'
    }

    # (display name, trading days) for every market that returned data.
    fetched = []
    for code, label in market_names.items():
        print(f"\n获取 {label} 交易日历...")
        dates = get_calendar_for_year(source, year, code)
        if dates is None or len(dates) == 0:
            continue
        fetched.append((label, dates))

    print("\n交易日对比:")
    print(f"{'市场':<20} {'交易日数':<10} {'起始日期':<12} {'结束日期':<12}")
    print("-" * 60)
    for label, dates in fetched:
        first = dates.min().strftime('%Y-%m-%d')
        last = dates.max().strftime('%Y-%m-%d')
        print(f"{label:<20} {len(dates):<10} {first:<12} {last:<12}")

    # Every unordered pair of markets, in insertion order.
    if len(fetched) >= 2:
        print("\n交易日差异:")
        for idx, (left_name, left_dates) in enumerate(fetched):
            for right_name, right_dates in fetched[idx + 1:]:
                delta = len(left_dates) - len(right_dates)
                sign = '+' if delta > 0 else ''
                print(f" {left_name} vs {right_name}: "
                      f"相差 {abs(delta)} 天 ({sign}{delta})")

    print(f"\n{banner}")
def show_recent_dates(trading_dates: pd.DatetimeIndex, n: int = 10):
    """
    Print the last ``n`` entries of ``trading_dates`` with their weekday names.

    "Recent" means the tail of the given sequence, not relative to today: if
    the source returns trading days after today (e.g. for the full current
    year), those are what gets printed.

    Args:
        trading_dates: Trading days, assumed to be in ascending order.
        n: Number of trailing dates to print. If fewer than ``n`` dates exist,
            all are printed; ``n <= 0`` prints only the header line.
    """
    if trading_dates is None or len(trading_dates) == 0:
        return
    print(f"\n最近 {n} 个交易日:")
    # Slice from an explicit start offset: ``seq[-n:]`` with n == 0 returns the
    # whole sequence (and a negative n slices from the front) instead of nothing.
    start = max(len(trading_dates) - max(n, 0), 0)
    for date in trading_dates[start:]:
        print(f" {date.strftime('%Y-%m-%d')} ({date.strftime('%A')})")
def export_calendar(trading_dates: pd.DatetimeIndex, output_path: str, year: int,
                    market: str = 'A'):
    """
    Export a trading calendar to ``<output_path>/trading_calendar_<market>_<year>.csv``.

    Columns: date, year, month, quarter, weekday (weekday name via
    ``strftime('%A')``). The output directory is created if it does not exist.
    Nothing is written for None/empty input.

    Args:
        trading_dates: Trading days to export; anything ``pd.DatetimeIndex``
            accepts (e.g. a list of date strings) is converted first.
        output_path: Directory the CSV file is written to.
        year: Label interpolated into the file name.
        market: Market code interpolated into the file name. The default
            ``'A'`` keeps the previous ``trading_calendar_A_<year>.csv`` name.
    """
    if trading_dates is None or len(trading_dates) == 0:
        return

    dates = pd.DatetimeIndex(trading_dates)
    df = pd.DataFrame({
        'date': dates,
        'year': dates.year,
        'month': dates.month,
        'quarter': dates.quarter,
        'weekday': dates.strftime('%A'),
    })

    out_dir = Path(output_path)
    # to_csv does not create parent directories; without this the export fails
    # unless the (default "data") directory already exists.
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"trading_calendar_{market}_{year}.csv"
    df.to_csv(filename, index=False)
    print(f"\n✓ 交易日历已导出到: {filename}")
    print(f" 文件包含 {len(df)} 条记录")
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser (same options and defaults as before)."""
    parser = argparse.ArgumentParser(description='获取 A 股交易日历')
    parser.add_argument(
        '--year',
        type=int,
        default=datetime.now().year,
        help='年份(默认当前年份)'
    )
    parser.add_argument(
        '--start',
        type=str,
        help='起始日期 YYYY-MM-DD'
    )
    parser.add_argument(
        '--end',
        type=str,
        help='结束日期 YYYY-MM-DD'
    )
    parser.add_argument(
        '--market',
        type=str,
        default='A',
        choices=['A', 'US', 'HK'],
        help='市场代码A=A股, US=美股, HK=港股)'
    )
    parser.add_argument(
        '--compare',
        action='store_true',
        help='对比不同市场交易日历'
    )
    parser.add_argument(
        '--export',
        action='store_true',
        help='导出交易日历到 CSV'
    )
    parser.add_argument(
        '--output',
        type=str,
        default='data',
        help='导出目录(默认 data'
    )
    return parser


def _print_calendar_info(source: FlaskAPIDataSource):
    """Print the calendar service's supported markets and data source unless it reports an error."""
    calendar_info = source.get_calendar_info()
    if 'error' not in calendar_info:
        print(f"\n交易日历服务信息:")
        print(f" 支持市场: {', '.join(calendar_info.get('markets', []))}")
        print(f" 数据源: {calendar_info.get('source', 'pandas_market_calendars')}")


def main():
    """
    Entry point: parse CLI options, check the Flask API service, then run one mode.

    Modes (first match wins):
      * ``--compare``: compare A/US/HK calendars for ``--year``.
      * ``--start`` and ``--end``: fetch ``--market`` over that explicit range.
      * otherwise: fetch, analyse and show ``--market`` for ``--year``.
    ``--export`` additionally writes the fetched calendar to CSV under
    ``--output`` (not in ``--compare`` mode).

    Exits with status 1 if the service health check fails. Exits with status 2
    (via ``parser.error``) if only one of ``--start``/``--end`` is given.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # A lone --start or --end used to be silently ignored (falling back to
    # whole-year mode); reject the half-specified range explicitly.
    if bool(args.start) != bool(args.end):
        parser.error('--start 和 --end 必须同时指定')

    print("\n初始化 Flask API 数据源...")
    source = FlaskAPIDataSource()

    # Abort early if the service is not reporting itself healthy.
    health = source.get_health()
    if health.get('status') != 'healthy':
        print(f"✗ Flask API 服务不可用: {health}")
        sys.exit(1)
    print(f"✓ Flask API 服务可用 ({source.base_url})")

    _print_calendar_info(source)

    if args.compare:
        compare_markets(source, args.year)
    elif args.start and args.end:
        print(f"\n获取 {args.market} 市场交易日历 ({args.start} ~ {args.end})...")
        trading_dates = source.get_trading_calendar(
            market=args.market,
            start_date=args.start,
            end_date=args.end
        )
        if trading_dates is None:
            print(f"✗ 未能获取 {args.market} 市场交易日历 ({args.start} ~ {args.end})")
        else:
            print(f"✓ 获取到 {len(trading_dates)} 个交易日")
            show_recent_dates(trading_dates)
            if args.export:
                # export_calendar only interpolates its `year` argument into the
                # file name. Label the file by the requested range: --year
                # defaults to the current year, is unrelated to this range, and
                # would overwrite that year's export.
                export_calendar(trading_dates, args.output, f"{args.start}_{args.end}")
    else:
        trading_dates = get_calendar_for_year(source, args.year, args.market)
        if trading_dates is not None:
            analyze_calendar(trading_dates, args.year)
            show_recent_dates(trading_dates)
            if args.export:
                export_calendar(trading_dates, args.output, args.year)

    print("\n✓ 完成!")


if __name__ == "__main__":
    main()