New: - rotation/simple_rotation.py: daily-iteration rotation strategy (584 lines) - rotation/config_loader.py: standalone config loader - rotation/config_simple.yaml: 11 assets, 7 groups - rotation/README_SIMPLE.md: usage guide - scripts/get_trading_calendar.py: trading calendar fetcher Removed: - rotation/example_usage.py, run_strategy.py (replaced by simple_rotation.py) - rotation/results/ output files (gitignored) - scripts/verify_*.py, calculate_returns_from_detail.py (one-off scripts) - scripts/README_TRADING_CALENDAR.md Backtest result (2020-01-10 ~ 2026-06-01): - Total return: 1237.6%, Annual: 52.66% - Max drawdown: -11.71%, Sharpe: 2.50
348 lines
10 KiB
Python
348 lines
10 KiB
Python
"""
|
||
获取 A 股交易日历脚本
|
||
|
||
使用 Flask API 交易日历服务获取 A 股交易日历
|
||
支持多市场、多年份的交易日查询
|
||
|
||
用法:
|
||
python scripts/get_trading_calendar.py
|
||
python scripts/get_trading_calendar.py --year 2024
|
||
python scripts/get_trading_calendar.py --start 2024-01-01 --end 2024-12-31
|
||
"""
|
||
|
||
import sys
|
||
import argparse
|
||
from pathlib import Path
|
||
from datetime import datetime, timedelta
|
||
import pandas as pd
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent
|
||
if str(project_root) not in sys.path:
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
# 加载环境变量
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
# 导入 Flask API 数据源
|
||
from datasource.flask_api_source import FlaskAPIDataSource
|
||
|
||
|
||
def get_calendar_for_year(source: FlaskAPIDataSource, year: int, market: str = 'A'):
|
||
"""
|
||
获取指定年份的交易日历
|
||
|
||
Args:
|
||
source: Flask API 数据源实例
|
||
year: 年份(如 2024)
|
||
market: 市场代码('A', 'US', 'HK')
|
||
|
||
Returns:
|
||
pd.DatetimeIndex: 交易日序列
|
||
"""
|
||
start_date = f"{year}-01-01"
|
||
end_date = f"{year}-12-31"
|
||
|
||
print(f"\n获取 {year} 年 {market} 市场交易日历...")
|
||
|
||
trading_dates = source.get_trading_calendar(
|
||
market=market,
|
||
start_date=start_date,
|
||
end_date=end_date
|
||
)
|
||
|
||
if trading_dates is None or len(trading_dates) == 0:
|
||
print(f"✗ {year} 年 {market} 市场无交易日数据")
|
||
return None
|
||
|
||
return trading_dates
|
||
|
||
|
||
def analyze_calendar(trading_dates: pd.DatetimeIndex, year: int):
|
||
"""
|
||
分析交易日历统计信息
|
||
|
||
Args:
|
||
trading_dates: 交易日序列
|
||
year: 年份
|
||
"""
|
||
if trading_dates is None or len(trading_dates) == 0:
|
||
return
|
||
|
||
print(f"\n{'=' * 60}")
|
||
print(f"{year} 年 A 股交易日历分析")
|
||
print(f"{'=' * 60}")
|
||
|
||
# 基本统计
|
||
total_days = len(trading_dates)
|
||
print(f"\n基本统计:")
|
||
print(f" 总交易日: {total_days} 天")
|
||
print(f" 起始日期: {trading_dates.min().strftime('%Y-%m-%d')}")
|
||
print(f" 结束日期: {trading_dates.max().strftime('%Y-%m-%d')}")
|
||
|
||
# 按月份统计
|
||
print(f"\n按月份统计:")
|
||
monthly_counts = {}
|
||
for date in trading_dates:
|
||
month = date.month
|
||
monthly_counts[month] = monthly_counts.get(month, 0) + 1
|
||
|
||
for month in range(1, 13):
|
||
count = monthly_counts.get(month, 0)
|
||
month_name = datetime(2024, month, 1).strftime('%B')
|
||
print(f" {month:02d}月 ({month_name}): {count} 天")
|
||
|
||
# 按季度统计
|
||
print(f"\n按季度统计:")
|
||
quarterly_counts = {1: 0, 2: 0, 3: 0, 4: 0}
|
||
for date in trading_dates:
|
||
quarter = (date.month - 1) // 3 + 1
|
||
quarterly_counts[quarter] += 1
|
||
|
||
for quarter, count in quarterly_counts.items():
|
||
print(f" Q{quarter}: {count} 天")
|
||
|
||
# 特殊日期统计
|
||
print(f"\n特殊日期:")
|
||
first_date = trading_dates.min()
|
||
last_date = trading_dates.max()
|
||
print(f" 首个交易日: {first_date.strftime('%Y-%m-%d')} ({first_date.strftime('%A')})")
|
||
print(f" 最后交易日: {last_date.strftime('%Y-%m-%d')} ({last_date.strftime('%A')})")
|
||
|
||
# 查找节假日后的首个交易日(通过间隔判断)
|
||
gaps = []
|
||
for i in range(1, len(trading_dates)):
|
||
prev_date = trading_dates[i-1]
|
||
curr_date = trading_dates[i]
|
||
gap_days = (curr_date - prev_date).days
|
||
if gap_days > 3: # 超过3天视为可能节假日
|
||
gaps.append({
|
||
'prev': prev_date,
|
||
'curr': curr_date,
|
||
'gap': gap_days
|
||
})
|
||
|
||
if gaps:
|
||
print(f"\n可能的节假日(间隔 > 3天):")
|
||
for gap_info in gaps[:5]: # 只显示前5个
|
||
print(f" {gap_info['prev'].strftime('%Y-%m-%d')} → {gap_info['curr'].strftime('%Y-%m-%d')} "
|
||
f"(间隔 {gap_info['gap']} 天)")
|
||
|
||
print(f"\n{'=' * 60}")
|
||
|
||
|
||
def compare_markets(source: FlaskAPIDataSource, year: int):
|
||
"""
|
||
比较不同市场的交易日历
|
||
|
||
Args:
|
||
source: Flask API 数据源实例
|
||
year: 年份
|
||
"""
|
||
print(f"\n{'=' * 60}")
|
||
print(f"{year} 年不同市场交易日历对比")
|
||
print(f"{'=' * 60}")
|
||
|
||
markets = {
|
||
'A': 'A股(上交所/深交所)',
|
||
'US': '美股(NYSE)',
|
||
'HK': '港股(HKEX)'
|
||
}
|
||
|
||
results = {}
|
||
for market_code, market_name in markets.items():
|
||
print(f"\n获取 {market_name} 交易日历...")
|
||
trading_dates = get_calendar_for_year(source, year, market_code)
|
||
|
||
if trading_dates is not None and len(trading_dates) > 0:
|
||
results[market_code] = {
|
||
'name': market_name,
|
||
'dates': trading_dates,
|
||
'count': len(trading_dates)
|
||
}
|
||
|
||
# 对比统计
|
||
print(f"\n交易日对比:")
|
||
print(f"{'市场':<20} {'交易日数':<10} {'起始日期':<12} {'结束日期':<12}")
|
||
print("-" * 60)
|
||
|
||
for market_code, data in results.items():
|
||
print(f"{data['name']:<20} {data['count']:<10} "
|
||
f"{data['dates'].min().strftime('%Y-%m-%d'):<12} "
|
||
f"{data['dates'].max().strftime('%Y-%m-%d'):<12}")
|
||
|
||
# 计算差异
|
||
if len(results) >= 2:
|
||
print(f"\n交易日差异:")
|
||
market_codes = list(results.keys())
|
||
for i in range(len(market_codes)):
|
||
for j in range(i+1, len(market_codes)):
|
||
m1 = market_codes[i]
|
||
m2 = market_codes[j]
|
||
diff = results[m1]['count'] - results[m2]['count']
|
||
print(f" {results[m1]['name']} vs {results[m2]['name']}: "
|
||
f"相差 {abs(diff)} 天 ({'+' if diff > 0 else ''}{diff})")
|
||
|
||
print(f"\n{'=' * 60}")
|
||
|
||
|
||
def show_recent_dates(trading_dates: pd.DatetimeIndex, n: int = 10):
|
||
"""
|
||
显示最近的交易日
|
||
|
||
Args:
|
||
trading_dates: 交易日序列
|
||
n: 显示数量
|
||
"""
|
||
if trading_dates is None or len(trading_dates) == 0:
|
||
return
|
||
|
||
print(f"\n最近 {n} 个交易日:")
|
||
recent_dates = trading_dates[-n:] if len(trading_dates) >= n else trading_dates
|
||
|
||
for date in recent_dates:
|
||
weekday = date.strftime('%A')
|
||
print(f" {date.strftime('%Y-%m-%d')} ({weekday})")
|
||
|
||
|
||
def export_calendar(trading_dates: pd.DatetimeIndex, output_path: str, year: int):
|
||
"""
|
||
导出交易日历到 CSV
|
||
|
||
Args:
|
||
trading_dates: 交易日序列
|
||
output_path: 输出路径
|
||
year: 年份
|
||
"""
|
||
if trading_dates is None or len(trading_dates) == 0:
|
||
return
|
||
|
||
# 创建 DataFrame
|
||
df = pd.DataFrame({
|
||
'date': trading_dates,
|
||
'year': trading_dates.year,
|
||
'month': trading_dates.month,
|
||
'quarter': (trading_dates.month - 1) // 3 + 1,
|
||
'weekday': [d.strftime('%A') for d in trading_dates]
|
||
})
|
||
|
||
# 导出到 CSV
|
||
filename = f"{output_path}/trading_calendar_A_{year}.csv"
|
||
df.to_csv(filename, index=False)
|
||
print(f"\n✓ 交易日历已导出到: {filename}")
|
||
print(f" 文件包含 {len(df)} 条记录")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
parser = argparse.ArgumentParser(description='获取 A 股交易日历')
|
||
|
||
parser.add_argument(
|
||
'--year',
|
||
type=int,
|
||
default=datetime.now().year,
|
||
help='年份(默认当前年份)'
|
||
)
|
||
|
||
parser.add_argument(
|
||
'--start',
|
||
type=str,
|
||
help='起始日期 YYYY-MM-DD'
|
||
)
|
||
|
||
parser.add_argument(
|
||
'--end',
|
||
type=str,
|
||
help='结束日期 YYYY-MM-DD'
|
||
)
|
||
|
||
parser.add_argument(
|
||
'--market',
|
||
type=str,
|
||
default='A',
|
||
choices=['A', 'US', 'HK'],
|
||
help='市场代码(A=A股, US=美股, HK=港股)'
|
||
)
|
||
|
||
parser.add_argument(
|
||
'--compare',
|
||
action='store_true',
|
||
help='对比不同市场交易日历'
|
||
)
|
||
|
||
parser.add_argument(
|
||
'--export',
|
||
action='store_true',
|
||
help='导出交易日历到 CSV'
|
||
)
|
||
|
||
parser.add_argument(
|
||
'--output',
|
||
type=str,
|
||
default='data',
|
||
help='导出目录(默认 data)'
|
||
)
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 初始化 Flask API 数据源
|
||
print("\n初始化 Flask API 数据源...")
|
||
source = FlaskAPIDataSource()
|
||
|
||
# 检查服务健康状态
|
||
health = source.get_health()
|
||
if health.get('status') != 'healthy':
|
||
print(f"✗ Flask API 服务不可用: {health}")
|
||
sys.exit(1)
|
||
|
||
print(f"✓ Flask API 服务可用 ({source.base_url})")
|
||
|
||
# 获取交易日历信息
|
||
calendar_info = source.get_calendar_info()
|
||
if 'error' not in calendar_info:
|
||
print(f"\n交易日历服务信息:")
|
||
print(f" 支持市场: {', '.join(calendar_info.get('markets', []))}")
|
||
print(f" 数据源: {calendar_info.get('source', 'pandas_market_calendars')}")
|
||
|
||
# 执行不同功能
|
||
if args.compare:
|
||
# 对比不同市场
|
||
compare_markets(source, args.year)
|
||
|
||
elif args.start and args.end:
|
||
# 自定义日期范围
|
||
print(f"\n获取 {args.market} 市场交易日历 ({args.start} ~ {args.end})...")
|
||
trading_dates = source.get_trading_calendar(
|
||
market=args.market,
|
||
start_date=args.start,
|
||
end_date=args.end
|
||
)
|
||
|
||
if trading_dates is not None:
|
||
print(f"✓ 获取到 {len(trading_dates)} 个交易日")
|
||
show_recent_dates(trading_dates)
|
||
|
||
if args.export:
|
||
export_calendar(trading_dates, args.output, args.year)
|
||
|
||
else:
|
||
# 获取指定年份交易日历
|
||
trading_dates = get_calendar_for_year(source, args.year, args.market)
|
||
|
||
if trading_dates is not None:
|
||
# 分析统计
|
||
analyze_calendar(trading_dates, args.year)
|
||
|
||
# 显示最近交易日
|
||
show_recent_dates(trading_dates)
|
||
|
||
# 导出
|
||
if args.export:
|
||
export_calendar(trading_dates, args.output, args.year)
|
||
|
||
print("\n✓ 完成!")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |