From 8b2c2be6f36193f4ff9bd84a7e4e658c74b661df Mon Sep 17 00:00:00 2001 From: aszerW Date: Thu, 7 May 2026 21:19:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E5=AE=9E=E7=8E=B0=20Flask=20RESTf?= =?UTF-8?q?ul=20API=20=E6=95=B0=E6=8D=AE=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Flask 服务提供统一 HTTP 接口 - 支持 6 个 API 端点:health、asset-type、ohlcv、batch、supported-codes - 集成 SSH 隧道自动管理(环境变量配置) - 提供一键启动脚本 start_flask_server.sh - 支持 CORS 跨域访问 - 完善的错误处理和响应格式化 --- core/datasource/flask_server.py | 540 ++++++++++++++++++++++++++++++++ start_flask_server.sh | 127 ++++++++ 2 files changed, 667 insertions(+) create mode 100644 core/datasource/flask_server.py create mode 100755 start_flask_server.sh diff --git a/core/datasource/flask_server.py b/core/datasource/flask_server.py new file mode 100644 index 0000000..edbfc32 --- /dev/null +++ b/core/datasource/flask_server.py @@ -0,0 +1,540 @@ +""" +Flask 数据服务 API +================== +提供 RESTful API 接口,支持获取各类资产的 K 线数据 + +支持的资产类型: +- A股指数 (000300.SH, 399006.SZ) +- A股ETF (510300.SH, 159915.SZ) +- A股股票 (600000.SH, 000001.SZ) +- 港股指数 (HSI, HSTECH.HK) +- 美股指数 (NDX, SPX, DJI) +- 美股股票 (AAPL, MSFT, GOOGL) +- 期货合约 (AU.SHF, CU.SHF) +- 加密货币 (BTC, ETH) + +运行: + python core/datasource/flask_server.py + +API 文档: + GET /health - 健康检查 + GET /api/v1/asset-type?code=000300.SH - 检测资产类型 + GET /api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31 - 获取K线数据 + POST /api/v1/ohlcv/batch - 批量获取K线数据 +""" + +import os +import sys +import json +from pathlib import Path +from datetime import datetime, timedelta +from typing import Optional, Dict, Any, List + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +from dotenv import load_dotenv +load_dotenv() + +from flask import Flask, request, jsonify +from flask_cors import CORS +import pandas as pd + +from core.datasource.universal_fetcher import ( + UniversalDataFetcher, + AssetTypeDetector, +) + + +# ============================================================ +# Flask 应用配置 +# ============================================================ + +app = Flask(__name__) +CORS(app) # 启用跨域支持 + +# 全局数据获取器实例 +fetcher: Optional[UniversalDataFetcher] = None +ssh_config: Optional[Dict] = None + + +def get_ssh_config() -> Optional[Dict]: + """ + 从环境变量获取 SSH 配置 + + 环境变量: + SSH_ENABLED: 是否启用 SSH 隧道 (true/false) + SSH_HOST: SSH 服务器地址 + SSH_PORT: SSH 端口 (默认 22) + SSH_USERNAME: SSH 用户名 + SSH_KEY_PATH: SSH 私钥路径 + SSH_LOCAL_PORT: 本地 SOCKS5 端口 (默认 1080) + """ + enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' + + if not enabled: + return None + + return { + "enabled": True, + "host": os.getenv('SSH_HOST', ''), + "port": int(os.getenv('SSH_PORT', '22')), + "username": os.getenv('SSH_USERNAME', ''), + "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'), + "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), + } + + +def get_fetcher() -> UniversalDataFetcher: + """获取或创建数据获取器实例""" + global fetcher, ssh_config + + if fetcher is None: + ssh_config = get_ssh_config() + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + + return fetcher + + +def dataframe_to_json(df: pd.DataFrame) -> Dict: + """将 DataFrame 转换为 JSON 可序列化的字典""" + if df is None or len(df) == 0: + return {"data": [], "count": 0} + + # 重置索引,将日期转为列 + df_reset = df.reset_index() + + # 处理日期列 + if 'date' in df_reset.columns: + df_reset['date'] = df_reset['date'].dt.strftime('%Y-%m-%d') + elif 'index' in df_reset.columns: + df_reset['index'] = pd.to_datetime(df_reset['index']).dt.strftime('%Y-%m-%d') + df_reset = df_reset.rename(columns={'index': 'date'}) + + # 转换为字典列表 + records = df_reset.to_dict(orient='records') + + return { + "data": records, + "count": len(records), + "columns": list(df_reset.columns), + "date_range": { + "start": df.index.min().strftime('%Y-%m-%d'), + "end": df.index.max().strftime('%Y-%m-%d'), + } if len(df) > 0 else None + } + + +def validate_date(date_str: str) -> bool: + """验证日期格式""" + try: + datetime.strptime(date_str, '%Y-%m-%d') + return True + except ValueError: + return False + + +def get_default_dates() -> tuple: + """获取默认日期范围(最近3个月)""" + end = datetime.now() + start = end - timedelta(days=90) + return start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d') + + +# ============================================================ +# API 路由 +# ============================================================ + +@app.route('/') +def index(): + """首页 - API 信息""" + return jsonify({ + "name": "Universal Data Fetcher API", + "version": "1.0.0", + "description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币", + "endpoints": { + "health": "/health", + "asset_type": "/api/v1/asset-type?code={code}", + "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}", + "batch": "POST /api/v1/ohlcv/batch", + }, + "supported_assets": [ + "A股指数 (000300.SH)", + "A股ETF (510300.SH)", + "A股股票 (600000.SH)", + "港股指数 (HSI)", + "美股指数 (NDX)", + "美股股票 (AAPL)", + "期货 (AU.SHF)", + "加密货币 (BTC)", + ], + "ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled", + }) + + +@app.route('/health') +def health(): + """健康检查""" + return jsonify({ + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "ssh_configured": ssh_config is not None and ssh_config.get('enabled', False), + }) + + +@app.route('/api/v1/asset-type') +def detect_asset_type(): + """ + 检测资产类型 + + Query Parameters: + code: 标的代码 (required) + + Returns: + { + "code": "000300.SH", + "asset_type": "china_index", + "description": "A股指数" + } + """ + code = request.args.get('code', '').strip() + + if not code: + return jsonify({ + "error": "Missing required parameter: code", + "example": "/api/v1/asset-type?code=000300.SH" + }), 400 + + asset_type = AssetTypeDetector.detect(code) + + # 资产类型描述映射 + descriptions = { + 'china_index': 'A股指数', + 'china_etf': 'A股ETF', + 'china_stock': 'A股股票', + 'hk_index': '港股指数', + 'hk_stock': '港股股票', + 'us_index': '美股指数', + 'us_stock': '美股股票', + 'futures': '期货合约', + 'crypto': '加密货币', + } + + return jsonify({ + "code": code, + "asset_type": asset_type, + "description": descriptions.get(asset_type, '未知类型'), + }) + + +@app.route('/api/v1/ohlcv') +def get_ohlcv(): + """ + 获取单只标的的 OHLCV 数据 + + Query Parameters: + code: 标的代码 (required) + start: 开始日期,格式 YYYY-MM-DD (optional, 默认90天前) + end: 结束日期,格式 YYYY-MM-DD (optional, 默认今天) + retry: 重试次数 (optional, 默认3) + + Returns: + { + "code": "000300.SH", + "asset_type": "china_index", + "data": [ + {"date": "2024-01-02", "open": 3500.0, "high": 3550.0, ...}, + ... + ], + "count": 58, + "date_range": {"start": "2024-01-02", "end": "2024-03-29"} + } + """ + code = request.args.get('code', '').strip() + start = request.args.get('start', '').strip() + end = request.args.get('end', '').strip() + retry = request.args.get('retry', '3') + + # 参数验证 + if not code: + return jsonify({ + "error": "Missing required parameter: code", + "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31" + }), 400 + + # 设置默认日期 + if not start or not end: + default_start, default_end = get_default_dates() + start = start or default_start + end = end or default_end + + # 日期格式验证 + if not validate_date(start) or not validate_date(end): + return jsonify({ + "error": "Invalid date format. Use YYYY-MM-DD", + "start": start, + "end": end, + }), 400 + + try: + retry_count = int(retry) + except ValueError: + retry_count = 3 + + # 检测资产类型 + asset_type = AssetTypeDetector.detect(code) + + # 获取数据 + fetcher = get_fetcher() + + try: + with fetcher: + df = fetcher.fetch(code, start, end, retry=retry_count) + + if df is None or len(df) == 0: + return jsonify({ + "code": code, + "asset_type": asset_type, + "error": "No data available for the specified code and date range", + "start": start, + "end": end, + }), 404 + + result = dataframe_to_json(df) + result['code'] = code + result['asset_type'] = asset_type + + return jsonify(result) + + except Exception as e: + return jsonify({ + "code": code, + "asset_type": asset_type, + "error": str(e), + "start": start, + "end": end, + }), 500 + + +@app.route('/api/v1/ohlcv/batch', methods=['POST']) +def batch_ohlcv(): + """ + 批量获取多只标的的 OHLCV 数据 + + Request Body: + { + "codes": ["000300.SH", "NDX", "HSI"], + "start": "2024-01-01", + "end": "2024-03-31", + "retry": 3 + } + + Returns: + { + "results": { + "000300.SH": {"data": [...], "count": 58, ...}, + "NDX": {"data": [...], "count": 61, ...}, + "HSI": {"error": "No data available", ...} + }, + "success_count": 2, + "failed_count": 1, + "total": 3 + } + """ + data = request.get_json() + + if not data: + return jsonify({ + "error": "Missing request body", + "example": { + "codes": ["000300.SH", "NDX"], + "start": "2024-01-01", + "end": "2024-03-31", + } + }), 400 + + codes = data.get('codes', []) + start = data.get('start', '').strip() + end = data.get('end', '').strip() + retry = data.get('retry', 3) + + if not codes or not isinstance(codes, list): + return jsonify({ + "error": "Missing or invalid parameter: codes (must be a list)" + }), 400 + + # 设置默认日期 + if not start or not end: + default_start, default_end = get_default_dates() + start = start or default_start + end = end or default_end + + # 日期格式验证 + if not validate_date(start) or not validate_date(end): + return jsonify({ + "error": "Invalid date format. Use YYYY-MM-DD", + "start": start, + "end": end, + }), 400 + + # 获取数据 + fetcher = get_fetcher() + results = {} + success_count = 0 + failed_count = 0 + + try: + with fetcher: + for code in codes: + try: + df = fetcher.fetch(code, start, end, retry=retry) + + if df is not None and len(df) > 0: + result = dataframe_to_json(df) + result['code'] = code + result['asset_type'] = AssetTypeDetector.detect(code) + results[code] = result + success_count += 1 + else: + results[code] = { + "code": code, + "asset_type": AssetTypeDetector.detect(code), + "error": "No data available", + "data": [], + "count": 0, + } + failed_count += 1 + + except Exception as e: + results[code] = { + "code": code, + "asset_type": AssetTypeDetector.detect(code), + "error": str(e), + "data": [], + "count": 0, + } + failed_count += 1 + + except Exception as e: + return jsonify({ + "error": f"Batch fetch failed: {str(e)}" + }), 500 + + return jsonify({ + "results": results, + "success_count": success_count, + "failed_count": failed_count, + "total": len(codes), + "start": start, + "end": end, + }) + + +@app.route('/api/v1/supported-codes') +def get_supported_codes(): + """ + 获取支持的代码示例 + + Returns: + { + "china_index": ["000300.SH", "399006.SZ", ...], + "china_etf": ["510300.SH", "159915.SZ", ...], + ... + } + """ + return jsonify({ + "china_index": { + "description": "A股指数", + "examples": ["000300.SH", "399006.SZ", "000016.SH", "H30269.CSI"], + }, + "china_etf": { + "description": "A股ETF", + "examples": ["510300.SH", "159915.SZ", "510500.SH", "513100.SH"], + }, + "china_stock": { + "description": "A股股票", + "examples": ["600000.SH", "000001.SZ"], + }, + "hk_index": { + "description": "港股指数", + "examples": ["HSI", "HSTECH.HK"], + }, + "us_index": { + "description": "美股指数", + "examples": ["NDX", "SPX", "DJI", "N225", "GDAXI"], + }, + "us_stock": { + "description": "美股股票", + "examples": ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"], + }, + "futures": { + "description": "期货合约", + "examples": ["AU.SHF", "CU.SHF"], + }, + "crypto": { + "description": "加密货币", + "examples": ["BTC", "ETH"], + }, + }) + + +# ============================================================ +# 错误处理 +# ============================================================ + +@app.errorhandler(404) +def not_found(error): + return jsonify({ + "error": "Endpoint not found", + "available_endpoints": [ + "/", + "/health", + "/api/v1/asset-type?code={code}", + "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}", + "/api/v1/ohlcv/batch", + "/api/v1/supported-codes", + ] + }), 404 + + +@app.errorhandler(500) +def internal_error(error): + return jsonify({ + "error": "Internal server error", + "message": str(error) + }), 500 + + +# ============================================================ +# 启动服务 +# ============================================================ + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server') + parser.add_argument('--host', default='0.0.0.0', help='Host to bind (default: 0.0.0.0)') + parser.add_argument('--port', type=int, default=5000, help='Port to bind (default: 5000)') + parser.add_argument('--debug', action='store_true', help='Enable debug mode') + + args = parser.parse_args() + + # 预加载 SSH 配置 + ssh_config = get_ssh_config() + if ssh_config and ssh_config.get('enabled'): + print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}") + else: + print("✗ SSH 隧道未启用(仅支持A股数据)") + + print(f"\n🚀 启动 Universal Data Fetcher API Server") + print(f" Host: {args.host}") + print(f" Port: {args.port}") + print(f" Debug: {args.debug}") + print(f"\n📖 API 文档: http://{args.host}:{args.port}/") + print(f" 健康检查: http://{args.host}:{args.port}/health") + print(f"\n💡 示例请求:") + print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'") + print(f"\n") + + app.run(host=args.host, port=args.port, debug=args.debug) diff --git a/start_flask_server.sh b/start_flask_server.sh new file mode 100755 index 0000000..48e6966 --- /dev/null +++ b/start_flask_server.sh @@ -0,0 +1,127 @@ +#!/bin/bash +# Flask API 服务启动脚本 +# ===================== + +# 颜色定义 +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' # No Color + +echo -e "${GREEN}Universal Data Fetcher API 服务启动脚本${NC}" +echo "==========================================" + +# 检查 Python +echo -e "\n1. 检查 Python 环境..." +if ! command -v python &> /dev/null; then + echo -e "${RED}✗ Python 未安装${NC}" + exit 1 +fi +echo -e "${GREEN}✓ Python 已安装: $(python --version)${NC}" + +# 检查依赖 +echo -e "\n2. 检查依赖..." +python -c "import flask" 2>/dev/null || { + echo -e "${YELLOW}⚠ Flask 未安装,正在安装...${NC}" + pip install flask flask-cors +} +echo -e "${GREEN}✓ 依赖检查完成${NC}" + +# 检查环境变量 +echo -e "\n3. 检查环境变量..." +if [ -z "$TUSHARE_TOKEN" ]; then + if [ -f ".env" ]; then + echo -e "${YELLOW}⚠ 从 .env 文件加载环境变量${NC}" + export $(cat .env | grep -v '^#' | xargs) + else + echo -e "${RED}✗ TUSHARE_TOKEN 未设置${NC}" + echo " 请在 .env 文件中设置 TUSHARE_TOKEN" + exit 1 + fi +fi +echo -e "${GREEN}✓ 环境变量检查完成${NC}" + +# 检查 SSH 配置 +echo -e "\n4. 检查 SSH 配置..." +if [ -f "hk_ecs.pem" ]; then + echo -e "${GREEN}✓ SSH 私钥文件存在 (hk_ecs.pem)${NC}" + + # 检查权限 + PERM=$(stat -f "%Lp" hk_ecs.pem 2>/dev/null || stat -c "%a" hk_ecs.pem 2>/dev/null) + if [ "$PERM" != "600" ]; then + echo -e "${YELLOW}⚠ 修复 SSH 私钥权限...${NC}" + chmod 600 hk_ecs.pem + fi + echo -e "${GREEN}✓ SSH 私钥权限正确 (600)${NC}" +else + echo -e "${YELLOW}⚠ SSH 私钥文件不存在,港美股数据获取将受限${NC}" +fi + +# 解析参数 +HOST="0.0.0.0" +PORT="5000" +DEBUG="" + +while [[ $# -gt 0 ]]; do + case $1 in + --host) + HOST="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --debug) + DEBUG="--debug" + shift + ;; + --with-ssh) + export SSH_ENABLED=true + export SSH_HOST=8.218.167.69 + export SSH_PORT=22 + export SSH_USERNAME=root + export SSH_KEY_PATH=hk_ecs.pem + export SSH_LOCAL_PORT=1080 + echo -e "${GREEN}✓ SSH 隧道已启用${NC}" + shift + ;; + --help) + echo "" + echo "用法: ./start_flask_server.sh [选项]" + echo "" + echo "选项:" + echo " --host HOST 绑定主机 (默认: 0.0.0.0)" + echo " --port PORT 绑定端口 (默认: 5000)" + echo " --debug 启用调试模式" + echo " --with-ssh 启用 SSH 隧道" + echo " --help 显示帮助" + echo "" + echo "示例:" + echo " ./start_flask_server.sh" + echo " ./start_flask_server.sh --port 8080" + echo " ./start_flask_server.sh --with-ssh" + echo " ./start_flask_server.sh --host 127.0.0.1 --port 5000 --debug --with-ssh" + exit 0 + ;; + *) + echo -e "${RED}未知选项: $1${NC}" + echo "使用 --help 查看帮助" + exit 1 + ;; + esac +done + +# 启动服务 +echo -e "\n5. 启动 Flask 服务..." +echo -e " 主机: ${YELLOW}$HOST${NC}" +echo -e " 端口: ${YELLOW}$PORT${NC}" +echo -e " 调试: ${YELLOW}$([ -n "$DEBUG" ] && echo "是" || echo "否")${NC}" +echo -e " SSH: ${YELLOW}$([ "$SSH_ENABLED" = "true" ] && echo "启用" || echo "禁用")${NC}" +echo "" + +echo -e "${GREEN}✓ 服务启动中...${NC}" +echo "==========================================" +echo "" + +python core/datasource/flask_server.py --host "$HOST" --port "$PORT" $DEBUG