Compare commits
8 Commits
70515ab169
...
1d3483bc02
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d3483bc02 | |||
| 4cee249823 | |||
| 2fba6d82f4 | |||
| 5c98b1cb6a | |||
| 50032d628f | |||
| c36044a1d6 | |||
| aeb95a6f4c | |||
| 0a8d0d9212 |
@@ -124,7 +124,7 @@ portfolio = executor.execute(signals, data)
|
||||
├── docker-compose.yml # Docker部署
|
||||
├── Dockerfile # 应用镜像
|
||||
├── Dockerfile_base # 基础镜像
|
||||
├── config/hk_ecs.pem # SSH密钥(港美股数据隧道)
|
||||
├── hk_ecs.pem # SSH密钥(港美股数据隧道)
|
||||
├── README.md # 本文件
|
||||
└── requirements.txt # 依赖
|
||||
```
|
||||
|
||||
@@ -82,7 +82,7 @@ echo ""
|
||||
echo "可以使用以下命令运行容器:"
|
||||
echo "docker run -d --name etf-scheduler-container \"
|
||||
echo " -v /path/to/.env:/app/.env \"
|
||||
echo " -v /path/to/config/hk_ecs.pem:/app/config/hk_ecs.pem \"
|
||||
echo " -v /path/to/hk_ecs.pem:/app/hk_ecs.pem \"
|
||||
echo " -v /path/to/data:/app/data \"
|
||||
echo " ${FULL_IMAGE_NAME}"
|
||||
echo ""
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""
|
||||
ETF策略项目 - 通用配置
|
||||
|
||||
敏感信息通过环境变量读取,非敏感配置直接定义
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 加载 .env 文件
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
pass # python-dotenv 未安装时跳过
|
||||
|
||||
# 项目根目录
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
|
||||
# 数据目录
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
DATA_CACHE_DIR = PROJECT_ROOT / "data_cache"
|
||||
|
||||
# 确保目录存在
|
||||
DATA_CACHE_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# ==================== 钉钉配置 ====================
|
||||
def get_dingtalk_config() -> dict:
    """Return the DingTalk settings of the default (first) group.

    Reads ``DINGTALK_WEBHOOK`` and ``DINGTALK_SECRET`` from the environment.

    Returns:
        A dict with the keys ``"webhook"`` and ``"secret"``; each value is an
        empty string when the corresponding variable is not set.
    """
    return {
        "webhook": os.getenv("DINGTALK_WEBHOOK", ""),
        "secret": os.getenv("DINGTALK_SECRET", ""),
    }


def get_all_dingtalk_configs(max_groups: int = 9) -> list[dict]:
    """Return the settings of every DingTalk group that has a webhook set.

    Group 1 comes from ``DINGTALK_WEBHOOK`` / ``DINGTALK_SECRET``. Further
    groups are read from the numbered variables ``DINGTALK_WEBHOOK_<i>`` /
    ``DINGTALK_SECRET_<i>`` for ``i`` in ``2..max_groups``.

    Args:
        max_groups: Highest group number to look for. The default of 9 scans
            the suffixes ``_2`` to ``_9``; values below 2 return only the
            primary group.

    Returns:
        A list of ``{"webhook": ..., "secret": ...}`` dicts in group order.
        Groups without a webhook are skipped; a missing secret is reported as
        an empty string.
    """
    configs = []

    # Group 1 (primary group) uses the un-suffixed variables.
    primary = get_dingtalk_config()
    if primary["webhook"]:
        configs.append(primary)

    # Groups 2..max_groups use the numbered variables. Gaps are allowed, so
    # the scan does not stop at the first unset number.
    for index in range(2, max_groups + 1):
        webhook = os.getenv(f"DINGTALK_WEBHOOK_{index}", "")
        if not webhook:
            continue
        secret = os.getenv(f"DINGTALK_SECRET_{index}", "")
        configs.append({"webhook": webhook, "secret": secret})

    return configs
|
||||
|
||||
|
||||
# ==================== 数据库配置 ====================
|
||||
def get_db_config() -> dict:
    """Build the database connection settings from environment variables.

    Reads ``DB_HOST``, ``DB_PORT``, ``DB_NAME``, ``DB_USER`` and ``DB_PASS``,
    falling back to the built-in defaults for any variable that is not set.

    Returns:
        A dict with the keys ``"host"``, ``"port"`` (an ``int``),
        ``"database"``, ``"username"`` and ``"password"``.

    Raises:
        ValueError: If ``DB_PORT`` is set to a value that is not an integer.
    """
    # A variable that is present but blank (e.g. "DB_PORT=" in a .env file)
    # is treated as unset instead of failing inside int("").
    raw_port = os.getenv("DB_PORT", "").strip()
    if not raw_port:
        port = 5432
    else:
        try:
            port = int(raw_port)
        except ValueError:
            raise ValueError(
                f"DB_PORT must be an integer, got {raw_port!r}"
            ) from None

    # NOTE(review): the fallback credentials below are hard-coded; deployments
    # should always provide DB_USER / DB_PASS through the environment.
    return {
        "host": os.getenv("DB_HOST", "192.168.0.115"),
        "port": port,
        "database": os.getenv("DB_NAME", "etf_db"),
        "username": os.getenv("DB_USER", "admin"),
        "password": os.getenv("DB_PASS", "admin"),
    }
|
||||
|
||||
|
||||
# ==================== 代码映射(默认,可被策略配置覆盖)====================
|
||||
DEFAULT_CODE_NAME_MAP = {
|
||||
# 宽基
|
||||
"000300.SH": "沪深300",
|
||||
"000905.SH": "中证500",
|
||||
"000852.SH": "中证1000",
|
||||
"399006.SZ": "创业板指",
|
||||
"000015.SH": "上证红利",
|
||||
# 金融
|
||||
"399986.SZ": "中证银行",
|
||||
"399975.SZ": "证券公司",
|
||||
"000934.SH": "中证金融",
|
||||
# 消费
|
||||
"000932.SH": "中证消费",
|
||||
"399997.SZ": "中证白酒",
|
||||
# 医药
|
||||
"000933.SH": "中证医药",
|
||||
"399989.SZ": "中证医疗",
|
||||
# 科技
|
||||
"000935.SH": "中证信息",
|
||||
"399971.SZ": "中证传媒",
|
||||
# 新能源
|
||||
"399808.SZ": "中证新能源",
|
||||
"399976.SZ": "新能源车",
|
||||
# 周期
|
||||
"399395.SZ": "国证有色",
|
||||
"399440.SZ": "中证钢铁",
|
||||
"399998.SZ": "中证煤炭",
|
||||
"399813.SZ": "细分化工",
|
||||
"000937.SH": "中证能源",
|
||||
"000938.SH": "中证材料",
|
||||
# 其他
|
||||
"399967.SZ": "中证军工",
|
||||
"399393.SZ": "国证地产",
|
||||
"000827.SH": "中证环保",
|
||||
"399995.SZ": "中证基建",
|
||||
"000949.SH": "中证农业",
|
||||
"399702.SZ": "中证国债指数",
|
||||
}
|
||||
|
||||
# 基准指数(默认,可被策略配置覆盖)
|
||||
DEFAULT_BENCHMARK_CODE = "000300.SH"
|
||||
DEFAULT_BENCHMARK_NAME = "沪深300指数"
|
||||
@@ -1,27 +0,0 @@
|
||||
# CCI技术指标筛选配置
|
||||
|
||||
# ==================== 数据源配置 ====================
|
||||
# 数据来源: "postgresql" 或 "akshare"
|
||||
data_source: "postgresql"
|
||||
|
||||
# ==================== 筛选参数 ====================
|
||||
# CCI指标周期
|
||||
day_period: 14
|
||||
week_period: 14
|
||||
|
||||
# 筛选阈值(低于该值视为超卖信号)
|
||||
threshold: -100
|
||||
|
||||
# 数据获取天数(用于计算CCI)
|
||||
lookback_days: 100
|
||||
|
||||
# ==================== 标的池 ====================
|
||||
# 指数代码列表文件路径(CSV格式,需包含"指数代码"和"指数名称"列)
|
||||
index_fund_info_file: "index_fund_info.csv"
|
||||
|
||||
# ==================== 定时任务 ====================
|
||||
# 运行时间(24小时制)
|
||||
schedule_time: "19:00"
|
||||
|
||||
# 是否跳过周末
|
||||
skip_weekend: true
|
||||
@@ -61,6 +61,10 @@ ssh_config: Optional[Dict] = None
|
||||
CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128'))
|
||||
CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200')) # 默认2小时
|
||||
|
||||
# 默认数据起点(下载全量数据时使用)
|
||||
# 设置为1980年以支持最长历史数据(标普500/日经225等)
|
||||
DEFAULT_START_DATE = os.getenv('DEFAULT_START_DATE', '1980-01-01')
|
||||
|
||||
|
||||
class TimedCacheEntry:
|
||||
"""带时间戳的缓存条目"""
|
||||
@@ -92,7 +96,7 @@ def get_ssh_config() -> Optional[Dict]:
|
||||
"host": os.getenv('SSH_HOST', ''),
|
||||
"port": int(os.getenv('SSH_PORT', '22')),
|
||||
"username": os.getenv('SSH_USERNAME', ''),
|
||||
"key_path": os.getenv('SSH_KEY_PATH', 'config/hk_ecs.pem'),
|
||||
"key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'),
|
||||
"local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')),
|
||||
}
|
||||
|
||||
@@ -113,29 +117,97 @@ def get_fetcher() -> UniversalDataFetcher:
|
||||
# ============================================================
|
||||
|
||||
@lru_cache(maxsize=CACHE_MAXSIZE)
|
||||
def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]:
|
||||
def _fetch_full_data_cached(code: str, today: str) -> Optional[str]:
|
||||
"""
|
||||
获取数据的缓存版本
|
||||
返回 JSON 序列化的字符串
|
||||
缓存全量数据(从 DEFAULT_START_DATE 到 today)
|
||||
|
||||
缓存Key: (code, today_date)
|
||||
- today: 实际的今天日期,用于每日更新缓存
|
||||
- 每天下载一次全量数据,避免重复请求
|
||||
|
||||
Returns:
|
||||
JSON 序列化的全量数据
|
||||
"""
|
||||
f = get_fetcher()
|
||||
|
||||
try:
|
||||
with f:
|
||||
df = f.fetch(code, start, end)
|
||||
# 下载全量数据:从默认起点到今天
|
||||
df = f.fetch(code, DEFAULT_START_DATE, today)
|
||||
|
||||
if df is None or len(df) == 0:
|
||||
return None
|
||||
|
||||
result = dataframe_to_json(df)
|
||||
result['code'] = code
|
||||
result['asset_type'] = AssetTypeDetector.detect(code).value
|
||||
# 保存为 DataFrame 格式(方便后续切片)
|
||||
result = {
|
||||
'df_json': dataframe_to_json(df),
|
||||
'code': code,
|
||||
'asset_type': AssetTypeDetector.detect(code).value,
|
||||
'data_start': df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None,
|
||||
'data_end': df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None,
|
||||
}
|
||||
|
||||
return json.dumps(result)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
def _slice_data_from_cache(cached_data: Dict, start: str, end: str) -> Dict:
|
||||
"""
|
||||
从缓存的全量数据中切片指定日期范围
|
||||
|
||||
Args:
|
||||
cached_data: 缓存的全量数据
|
||||
start: 用户请求的开始日期
|
||||
end: 用户请求的结束日期
|
||||
|
||||
Returns:
|
||||
切片后的数据(JSON格式)
|
||||
"""
|
||||
if 'df_json' not in cached_data or 'data' not in cached_data['df_json']:
|
||||
return cached_data
|
||||
|
||||
# 从缓存数据中重建 DataFrame
|
||||
records = cached_data['df_json']['data']
|
||||
if not records:
|
||||
return cached_data
|
||||
|
||||
# 转换为 DataFrame
|
||||
df = pd.DataFrame(records)
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
|
||||
# 切片日期范围
|
||||
start_dt = pd.to_datetime(start)
|
||||
end_dt = pd.to_datetime(end)
|
||||
|
||||
# 确保索引已排序
|
||||
df = df.sort_index()
|
||||
|
||||
# 切片(使用 loc 进行日期范围选择)
|
||||
sliced_df = df.loc[start_dt:end_dt]
|
||||
|
||||
if len(sliced_df) == 0:
|
||||
return {
|
||||
'data': [],
|
||||
'count': 0,
|
||||
'code': cached_data['code'],
|
||||
'asset_type': cached_data['asset_type'],
|
||||
'requested_range': {'start': start, 'end': end},
|
||||
'available_range': {'start': cached_data['data_start'], 'end': cached_data['data_end']},
|
||||
}
|
||||
|
||||
# 转换为 JSON 格式
|
||||
result = dataframe_to_json(sliced_df)
|
||||
result['code'] = cached_data['code']
|
||||
result['asset_type'] = cached_data['asset_type']
|
||||
result['requested_range'] = {'start': start, 'end': end}
|
||||
result['available_range'] = {'start': cached_data['data_start'], 'end': cached_data['data_end']}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fetch_data_with_ttl(
|
||||
code: str,
|
||||
start: str,
|
||||
@@ -145,56 +217,76 @@ def fetch_data_with_ttl(
|
||||
"""
|
||||
获取数据,支持 TTL 缓存
|
||||
|
||||
缓存策略:
|
||||
- Key: (code, today_date) 缓存全量数据
|
||||
- 每天下载一次全量数据(从 DEFAULT_START_DATE 到今天)
|
||||
- 用户请求时从缓存切片 start-end 范围返回
|
||||
|
||||
Args:
|
||||
code: 标的代码
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
start: 用户请求的开始日期
|
||||
end: 用户请求的结束日期
|
||||
nocache: 是否跳过缓存
|
||||
|
||||
Returns:
|
||||
(data, is_cached): 数据和是否命中缓存
|
||||
(data, is_cached): 切片后的数据和是否命中缓存
|
||||
"""
|
||||
cache_key = (code, start, end)
|
||||
# 获取今天的实际日期(用于缓存Key)
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
full_cache_key = (code, today)
|
||||
|
||||
# 跳过缓存
|
||||
# 跳过缓存:清理缓存后重新下载
|
||||
if nocache:
|
||||
_fetch_data_cached.cache_clear()
|
||||
result_json = _fetch_data_cached(code, start, end)
|
||||
return (json.loads(result_json) if result_json else None, False)
|
||||
_fetch_full_data_cached.cache_clear()
|
||||
global _ttl_cache
|
||||
_ttl_cache.clear()
|
||||
result_json = _fetch_full_data_cached(code, today)
|
||||
if result_json is None:
|
||||
return None, False
|
||||
full_data = json.loads(result_json)
|
||||
return (_slice_data_from_cache(full_data, start, end), False)
|
||||
|
||||
# 检查 TTL 缓存
|
||||
global _ttl_cache
|
||||
if cache_key in _ttl_cache:
|
||||
entry = _ttl_cache[cache_key]
|
||||
# 检查 TTL 缓存(全量数据缓存)
|
||||
if full_cache_key in _ttl_cache:
|
||||
entry = _ttl_cache[full_cache_key]
|
||||
if not entry.is_expired():
|
||||
return entry.data, True
|
||||
# 从缓存切片
|
||||
sliced_data = _slice_data_from_cache(entry.data, start, end)
|
||||
return sliced_data, True
|
||||
# 过期,删除
|
||||
del _ttl_cache[cache_key]
|
||||
del _ttl_cache[full_cache_key]
|
||||
|
||||
# 从 LRU 缓存获取
|
||||
result_json = _fetch_data_cached(code, start, end)
|
||||
# 从 LRU 缓存获取全量数据
|
||||
result_json = _fetch_full_data_cached(code, today)
|
||||
|
||||
if result_json is None:
|
||||
return None, False
|
||||
|
||||
result = json.loads(result_json)
|
||||
full_data = json.loads(result_json)
|
||||
|
||||
# 存入 TTL 缓存
|
||||
_ttl_cache[cache_key] = TimedCacheEntry(result)
|
||||
# 检查是否有错误
|
||||
if "error" in full_data:
|
||||
return full_data, False
|
||||
|
||||
return result, False
|
||||
# 存入 TTL 缓存(存全量数据)
|
||||
_ttl_cache[full_cache_key] = TimedCacheEntry(full_data)
|
||||
|
||||
# 从全量数据切片返回用户请求的范围
|
||||
sliced_data = _slice_data_from_cache(full_data, start, end)
|
||||
|
||||
return sliced_data, False
|
||||
|
||||
|
||||
def clear_cache():
|
||||
"""清理所有缓存"""
|
||||
global _ttl_cache
|
||||
_fetch_data_cached.cache_clear()
|
||||
_fetch_full_data_cached.cache_clear()
|
||||
_ttl_cache.clear()
|
||||
|
||||
|
||||
def get_cache_info() -> Dict:
|
||||
"""获取缓存统计信息"""
|
||||
info = _fetch_data_cached.cache_info()
|
||||
info = _fetch_full_data_cached.cache_info()
|
||||
return {
|
||||
"lru_cache": {
|
||||
"hits": info.hits,
|
||||
@@ -204,6 +296,8 @@ def get_cache_info() -> Dict:
|
||||
},
|
||||
"ttl_cache_size": len(_ttl_cache),
|
||||
"ttl_seconds": CACHE_TTL_SECONDS,
|
||||
"default_start_date": DEFAULT_START_DATE,
|
||||
"cache_strategy": "full_data_by_code_and_today",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class HybridDataSource:
|
||||
使用方式:
|
||||
from datasource import HybridDataSource
|
||||
|
||||
source = HybridDataSource.from_yaml('config/strategies/rotation.yaml')
|
||||
source = HybridDataSource.from_yaml('strategies/rotation/config.yaml')
|
||||
result = source.fetch_all()
|
||||
"""
|
||||
|
||||
@@ -259,7 +259,7 @@ class HybridDataSource:
|
||||
|
||||
|
||||
# 简化接口
|
||||
def fetch_rotation_data(config_path: str = "config/strategies/rotation.yaml") -> dict:
|
||||
def fetch_rotation_data(config_path: str = "strategies/rotation/config.yaml") -> dict:
|
||||
"""
|
||||
获取轮动策略数据(简化接口)
|
||||
|
||||
|
||||
@@ -42,6 +42,25 @@ class SSHTunnelManager:
|
||||
key_path = str(project_root / key_path)
|
||||
self.key_path = key_path
|
||||
|
||||
def _cleanup_old_processes(self):
    """Kill leftover SSH processes that forward this tunnel's local port.

    Best-effort cleanup: uses ``pgrep -f`` to find SSH processes whose
    command line has a ``-D`` dynamic-forward argument for
    ``self.local_port`` and force-kills each one. If ``pgrep`` or ``kill``
    cannot be run, or a process cannot be killed, the problem is ignored.
    """
    # Anchor the port to the -D argument (optionally "bind_address:port") and
    # require the number to end there. A loose "ssh.*-D.*<port>" would also
    # match e.g. "-D 10800" or an unrelated "-p <port>" later on the line.
    pattern = f'ssh.* -D *([^ ]*:)?{self.local_port}( |$)'

    try:
        result = subprocess.run(
            ['pgrep', '-f', pattern],
            capture_output=True, text=True
        )
    except (OSError, subprocess.SubprocessError):
        return  # pgrep is unavailable; nothing to clean up

    # pgrep exits non-zero when no process matched.
    if result.returncode != 0:
        return

    for pid in result.stdout.split():
        if not pid.isdigit():
            continue
        try:
            subprocess.run(['kill', '-9', pid], check=True)
        except (OSError, subprocess.CalledProcessError):
            continue  # already gone or not ours to kill
        print(f" 清理残留SSH进程: PID {pid}")
|
||||
|
||||
def start(self) -> bool:
|
||||
"""启动SSH隧道"""
|
||||
if not self.enabled:
|
||||
@@ -51,6 +70,9 @@ class SSHTunnelManager:
|
||||
print("SSH配置不完整,跳过隧道建立")
|
||||
return False
|
||||
|
||||
# 先清理残留的同端口SSH进程
|
||||
self._cleanup_old_processes()
|
||||
|
||||
print(f"建立SSH隧道: {self.host}:{self.port} -> 本地SOCKS5端口 {self.local_port}")
|
||||
|
||||
cmd = [
|
||||
|
||||
@@ -11,7 +11,7 @@ services:
|
||||
# 挂载环境变量文件(必需)
|
||||
- ./.env:/app/.env:ro
|
||||
# 挂载 SSH 私钥(必需,用于 yfinance 数据下载)
|
||||
- ./config/hk_ecs.pem:/app/config/hk_ecs.pem:ro
|
||||
- ./hk_ecs.pem:/app/hk_ecs.pem:ro
|
||||
# 挂载数据目录(持久化)
|
||||
- ./data:/app/data
|
||||
# 挂载日志目录
|
||||
|
||||
@@ -4,7 +4,7 @@ ETF轮动策略回测入口
|
||||
|
||||
用法:
|
||||
python run_rotation.py
|
||||
python run_rotation.py --config config/strategies/rotation.yaml
|
||||
python run_rotation.py --config strategies/rotation/config.yaml
|
||||
python run_rotation.py --save-path results/my_rotation
|
||||
"""
|
||||
|
||||
@@ -20,7 +20,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="config/strategies/rotation.yaml",
|
||||
default="strategies/rotation/config.yaml",
|
||||
help="配置文件路径",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -34,7 +34,7 @@ def run_with_legacy_report():
|
||||
"""运行新框架回测并生成原引擎格式报告"""
|
||||
|
||||
# 加载配置
|
||||
config_path = 'config/strategies/rotation.yaml'
|
||||
config_path = 'strategies/rotation/config.yaml'
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
|
||||
@@ -124,5 +124,5 @@ ssh_tunnel:
|
||||
host: "8.218.167.69" # SSH 服务器地址(阿里云香港 ECS IP)
|
||||
port: 22 # SSH 端口
|
||||
username: "root" # SSH 用户名
|
||||
key_path: "config/hk_ecs.pem" # SSH 私钥路径(相对于项目根目录)
|
||||
key_path: "hk_ecs.pem" # SSH 私钥路径(相对于项目根目录)
|
||||
local_port: 1080 # 本地 SOCKS5 代理端口
|
||||
@@ -32,7 +32,7 @@ class RotationStrategy(StrategyBase):
|
||||
|
||||
使用方式:
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
|
||||
strategy = RotationStrategy.from_yaml('strategies/rotation/config.yaml')
|
||||
result = strategy.run_backtest()
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user