feat(api): 实现 Flask RESTful API 数据服务

- 新增 Flask 服务提供统一 HTTP 接口
- 支持 6 个 API 端点:/(首页)、health、asset-type、ohlcv、ohlcv/batch、supported-codes
- 集成 SSH 隧道自动管理(环境变量配置)
- 提供一键启动脚本 start_flask_server.sh
- 支持 CORS 跨域访问
- 完善的错误处理和响应格式化
This commit is contained in:
2026-05-07 21:19:29 +08:00
parent e319426c10
commit 8b2c2be6f3
2 changed files with 667 additions and 0 deletions

View File

@@ -0,0 +1,540 @@
"""
Flask 数据服务 API
==================
提供 RESTful API 接口,支持获取各类资产的 K 线数据
支持的资产类型:
- A股指数 (000300.SH, 399006.SZ)
- A股ETF (510300.SH, 159915.SZ)
- A股股票 (600000.SH, 000001.SZ)
- 港股指数 (HSI, HSTECH.HK)
- 美股指数 (NDX, SPX, DJI)
- 美股股票 (AAPL, MSFT, GOOGL)
- 期货合约 (AU.SHF, CU.SHF)
- 加密货币 (BTC, ETH)
运行:
python core/datasource/flask_server.py
API 文档:
GET / - 首页 (API 信息)
GET /health - 健康检查
GET /api/v1/asset-type?code=000300.SH - 检测资产类型
GET /api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31 - 获取K线数据
POST /api/v1/ohlcv/batch - 批量获取K线数据
GET /api/v1/supported-codes - 获取支持的代码示例
"""
import os
import sys
import json
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from dotenv import load_dotenv
load_dotenv()
from flask import Flask, request, jsonify
from flask_cors import CORS
import pandas as pd
from core.datasource.universal_fetcher import (
UniversalDataFetcher,
AssetTypeDetector,
)
# ============================================================
# Flask 应用配置
# ============================================================
app = Flask(__name__)
CORS(app) # Enable cross-origin (CORS) support for the whole app
# Module-level data fetcher instance; stays None until get_fetcher() creates it
fetcher: Optional[UniversalDataFetcher] = None
# SSH settings; stays None until get_fetcher() or the __main__ block assigns it
ssh_config: Optional[Dict] = None
def get_ssh_config() -> Optional[Dict]:
    """Build the SSH tunnel settings from environment variables.

    Environment variables:
        SSH_ENABLED: Tunnel switch; only ``true`` (case-insensitive,
            surrounding whitespace ignored) enables the tunnel.
        SSH_HOST: SSH server address (default: empty string).
        SSH_PORT: SSH server port (default: 22).
        SSH_USERNAME: SSH user name (default: empty string).
        SSH_KEY_PATH: Path to the SSH private key (default: ``hk_ecs.pem``).
        SSH_LOCAL_PORT: Local SOCKS5 port (default: 1080).

    Returns:
        ``None`` when the tunnel is not enabled, otherwise a dict with the
        keys ``enabled``, ``host``, ``port``, ``username``, ``key_path`` and
        ``local_port``.

    Raises:
        ValueError: If a port variable is set to a non-integer value.
    """

    def _port_from_env(name: str, default: int) -> int:
        """Read an integer port; an unset or blank variable yields *default*."""
        raw = os.getenv(name, '').strip()
        if not raw:
            # ``SSH_PORT=`` lines are common in .env files; treat them as unset
            # instead of letting int('') blow up.
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(f"{name} must be an integer, got {raw!r}") from None

    if os.getenv('SSH_ENABLED', 'false').strip().lower() != 'true':
        return None

    return {
        "enabled": True,
        "host": os.getenv('SSH_HOST', ''),
        "port": _port_from_env('SSH_PORT', 22),
        "username": os.getenv('SSH_USERNAME', ''),
        "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'),
        "local_port": _port_from_env('SSH_LOCAL_PORT', 1080),
    }
def get_fetcher() -> UniversalDataFetcher:
    """Return the shared data fetcher, creating it on first use.

    On the first call the SSH settings are read via ``get_ssh_config()`` and
    stored in the module-level ``ssh_config``; the fetcher built from them is
    stored in the module-level ``fetcher`` and reused by later calls.
    """
    global fetcher, ssh_config

    # Fast path: an instance already exists.
    if fetcher is not None:
        return fetcher

    ssh_config = get_ssh_config()
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    return fetcher
def _json_safe(value):
    """Map missing-value markers (None, NaN, NaT, pd.NA) to None."""
    if value is None or value is pd.NaT or value is pd.NA:
        return None
    # NaN is the only float that is not equal to itself.
    if isinstance(value, float) and value != value:
        return None
    return value


def dataframe_to_json(df: pd.DataFrame) -> Dict:
    """Convert an OHLCV DataFrame into a JSON-serializable dict.

    Dates are taken from a ``date`` column when one exists, otherwise from
    the index (an unnamed index becomes the ``date`` column after reset).
    Dates are rendered as ``YYYY-MM-DD`` strings and missing values become
    ``None`` so the payload serializes to valid JSON (``null`` rather than
    ``NaN``).

    Args:
        df: Frame to convert; may be ``None`` or empty.

    Returns:
        ``{"data": [], "count": 0}`` for ``None``/empty input. Otherwise a
        dict with ``data`` (list of row dicts), ``count``, ``columns`` and
        ``date_range`` (``{"start", "end"}``, or ``None`` when no dates can
        be determined).
    """
    if df is None or len(df) == 0:
        return {"data": [], "count": 0}

    # When the dates already live in a column, a default RangeIndex carries no
    # information; drop it instead of emitting a meaningless 'index' column.
    dates_in_column = 'date' in df.columns
    drop_index = dates_in_column and isinstance(df.index, pd.RangeIndex)
    df_reset = df.reset_index(drop=drop_index)
    if not dates_in_column and 'index' in df_reset.columns:
        df_reset = df_reset.rename(columns={'index': 'date'})

    bounds = None
    if 'date' in df_reset.columns:
        # to_datetime is a no-op for datetime columns and parses string dates,
        # so both layouts go through the same formatting path.
        dates = pd.to_datetime(df_reset['date'])
        df_reset['date'] = dates.dt.strftime('%Y-%m-%d')
        bounds = (dates.min(), dates.max())
    elif isinstance(df.index, pd.DatetimeIndex):
        bounds = (df.index.min(), df.index.max())

    date_range = None
    if bounds is not None and not any(pd.isna(bound) for bound in bounds):
        date_range = {
            "start": bounds[0].strftime('%Y-%m-%d'),
            "end": bounds[1].strftime('%Y-%m-%d'),
        }

    records = [
        {key: _json_safe(value) for key, value in row.items()}
        for row in df_reset.to_dict(orient='records')
    ]
    return {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": date_range,
    }
def validate_date(date_str: str) -> bool:
    """Return True when *date_str* parses with the ``%Y-%m-%d`` format.

    Non-string input (for example ``None`` or a number taken from a JSON
    body) is reported as invalid rather than raising.

    Args:
        date_str: Candidate date text.

    Returns:
        ``True`` if ``datetime.strptime`` accepts the value, else ``False``.
    """
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except (TypeError, ValueError):
        # TypeError: not a string at all; ValueError: wrong format or date.
        return False
    return True
def get_default_dates() -> tuple:
    """Return the default ``(start, end)`` date strings.

    ``end`` is the current local date and ``start`` is 90 days earlier, both
    formatted as ``YYYY-MM-DD``.
    """
    date_format = '%Y-%m-%d'
    window_end = datetime.now()
    window_start = window_end - timedelta(days=90)
    return tuple(moment.strftime(date_format) for moment in (window_start, window_end))
# ============================================================
# API 路由
# ============================================================
@app.route('/')
def index():
    """Return service metadata for the API root.

    The JSON body carries the service name, version, description, the
    endpoint templates, example asset codes and the SSH tunnel status.

    Returns:
        A Flask JSON response.
    """
    # ``ssh_config`` is the module-level value; it stays None until
    # get_fetcher() or the ``__main__`` block assigns it, so the status reads
    # "disabled" before either has run.
    ssh_enabled = bool(ssh_config and ssh_config.get('enabled'))
    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "1.0.0",
        "description": "统一数据获取服务支持A股、港股、美股、期货、加密货币",
        "endpoints": {
            "health": "/health",
            "asset_type": "/api/v1/asset-type?code={code}",
            "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
            "batch": "POST /api/v1/ohlcv/batch",
            # Previously missing here although the route exists and the 404
            # handler advertises it.
            "supported_codes": "/api/v1/supported-codes",
        },
        "supported_assets": [
            "A股指数 (000300.SH)",
            "A股ETF (510300.SH)",
            "A股股票 (600000.SH)",
            "港股指数 (HSI)",
            "美股指数 (NDX)",
            "美股股票 (AAPL)",
            "期货 (AU.SHF)",
            "加密货币 (BTC)",
        ],
        "ssh_status": "enabled" if ssh_enabled else "disabled",
    })
@app.route('/health')
def health():
    """Report liveness, the current local timestamp and the SSH flag.

    ``ssh_configured`` mirrors the module-level ``ssh_config``: ``False``
    while it is ``None``, otherwise its ``enabled`` entry.
    """
    if ssh_config is None:
        ssh_flag = False
    else:
        ssh_flag = ssh_config.get('enabled', False)

    payload = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh_configured": ssh_flag,
    }
    return jsonify(payload)
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """Classify an instrument code.

    Query Parameters:
        code: Instrument code (required).

    Returns:
        On success, JSON such as::

            {
                "code": "000300.SH",
                "asset_type": "china_index",
                "description": "A股指数"
            }

        HTTP 400 with an ``error`` and ``example`` when ``code`` is missing
        or blank.
    """
    # Human-readable label for each asset type the detector may report.
    labels = dict(
        china_index='A股指数',
        china_etf='A股ETF',
        china_stock='A股股票',
        hk_index='港股指数',
        hk_stock='港股股票',
        us_index='美股指数',
        us_stock='美股股票',
        futures='期货合约',
        crypto='加密货币',
    )

    code = request.args.get('code', '').strip()
    if code:
        detected = AssetTypeDetector.detect(code)
        return jsonify({
            "code": code,
            "asset_type": detected,
            "description": labels.get(detected, '未知类型'),
        })

    return jsonify({
        "error": "Missing required parameter: code",
        "example": "/api/v1/asset-type?code=000300.SH"
    }), 400
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """Return OHLCV rows for a single instrument.

    Query Parameters:
        code: Instrument code (required).
        start: Start date, ``YYYY-MM-DD`` (optional; defaults to 90 days ago).
        end: End date, ``YYYY-MM-DD`` (optional; defaults to today).
        retry: Retry count (optional; defaults to 3, non-integer values
            fall back to 3).

    Returns:
        On success, JSON such as::

            {
                "code": "000300.SH",
                "asset_type": "china_index",
                "data": [
                    {"date": "2024-01-02", "open": 3500.0, "high": 3550.0, ...},
                    ...
                ],
                "count": 58,
                "date_range": {"start": "2024-01-02", "end": "2024-03-29"}
            }

        HTTP 400 for a missing code or a malformed date, 404 when the fetch
        yields no rows, 500 when the fetch raises.
    """
    params = request.args

    code = params.get('code', '').strip()
    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    # Fill whichever bound is missing from the default window.
    start = params.get('start', '').strip()
    end = params.get('end', '').strip()
    if not (start and end):
        default_start, default_end = get_default_dates()
        start = start if start else default_start
        end = end if end else default_end

    if not (validate_date(start) and validate_date(end)):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    try:
        retry_count = int(params.get('retry', '3'))
    except ValueError:
        retry_count = 3

    asset_type = AssetTypeDetector.detect(code)
    data_fetcher = get_fetcher()

    def failure(message, status):
        """Build the shared error payload for the 404 and 500 responses."""
        return jsonify({
            "code": code,
            "asset_type": asset_type,
            "error": message,
            "start": start,
            "end": end,
        }), status

    try:
        with data_fetcher:
            df = data_fetcher.fetch(code, start, end, retry=retry_count)

        if df is None or len(df) == 0:
            return failure("No data available for the specified code and date range", 404)

        payload = dataframe_to_json(df)
        payload.update(code=code, asset_type=asset_type)
        return jsonify(payload)
    except Exception as exc:
        return failure(str(exc), 500)
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """Return OHLCV rows for several instruments in one request.

    Request Body (JSON object):
        {
            "codes": ["000300.SH", "NDX", "HSI"],
            "start": "2024-01-01",
            "end": "2024-03-31",
            "retry": 3
        }

        ``start``/``end`` default to the last 90 days when missing or null;
        ``retry`` falls back to 3 when it is not an integer. Codes are
        converted to text, stripped and de-duplicated (first occurrence
        wins) so the counters always agree with ``results``.

    Returns:
        On success, JSON such as::

            {
                "results": {
                    "000300.SH": {"data": [...], "count": 58, ...},
                    "NDX": {"data": [...], "count": 61, ...},
                    "HSI": {"error": "No data available", ...}
                },
                "success_count": 2,
                "failed_count": 1,
                "total": 3,
                "start": "2024-01-01",
                "end": "2024-03-31"
            }

        HTTP 400 for a missing/non-object body, an invalid ``codes`` value or
        a malformed date; 500 when the batch as a whole fails.
    """

    def _as_text(value) -> str:
        """Render an optional JSON scalar as stripped text (None -> '')."""
        return '' if value is None else str(value).strip()

    # silent=True: a wrong Content-Type or malformed JSON yields None instead
    # of aborting with Flask's own error page, so the JSON 400 below is what
    # the client actually receives.
    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        return jsonify({
            "error": "Missing request body",
            "example": {
                "codes": ["000300.SH", "NDX"],
                "start": "2024-01-01",
                "end": "2024-03-31",
            }
        }), 400

    codes = data.get('codes', [])
    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400
    # Text keys are always hashable; de-duplication keeps success/failed/total
    # consistent with the number of entries in ``results``.
    codes = list(dict.fromkeys(_as_text(code) for code in codes))

    start = _as_text(data.get('start'))
    end = _as_text(data.get('end'))
    try:
        retry = int(data.get('retry', 3))
    except (TypeError, ValueError):
        retry = 3

    # Fill whichever bound is missing from the default window.
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    data_fetcher = get_fetcher()
    results = {}
    success_count = 0
    failed_count = 0

    try:
        with data_fetcher:
            for code in codes:
                asset_type = AssetTypeDetector.detect(code)
                try:
                    df = data_fetcher.fetch(code, start, end, retry=retry)
                    if df is not None and len(df) > 0:
                        result = dataframe_to_json(df)
                        result['code'] = code
                        result['asset_type'] = asset_type
                        success_count += 1
                    else:
                        result = {
                            "code": code,
                            "asset_type": asset_type,
                            "error": "No data available",
                            "data": [],
                            "count": 0,
                        }
                        failed_count += 1
                except Exception as exc:
                    # One failing code must not abort the rest of the batch.
                    result = {
                        "code": code,
                        "asset_type": asset_type,
                        "error": str(exc),
                        "data": [],
                        "count": 0,
                    }
                    failed_count += 1
                results[code] = result
    except Exception as exc:
        return jsonify({
            "error": f"Batch fetch failed: {str(exc)}"
        }), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })
@app.route('/api/v1/supported-codes')
def get_supported_codes():
    """Return example codes for each asset type.

    Returns:
        JSON keyed by asset type; each value holds a ``description`` and a
        list of ``examples``, e.g.::

            {
                "china_index": {
                    "description": "A股指数",
                    "examples": ["000300.SH", "399006.SZ", ...]
                },
                ...
            }
    """
    # (asset type, description, example codes)
    catalog = (
        ("china_index", "A股指数", ["000300.SH", "399006.SZ", "000016.SH", "H30269.CSI"]),
        ("china_etf", "A股ETF", ["510300.SH", "159915.SZ", "510500.SH", "513100.SH"]),
        ("china_stock", "A股股票", ["600000.SH", "000001.SZ"]),
        ("hk_index", "港股指数", ["HSI", "HSTECH.HK"]),
        ("us_index", "美股指数", ["NDX", "SPX", "DJI", "N225", "GDAXI"]),
        ("us_stock", "美股股票", ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]),
        ("futures", "期货合约", ["AU.SHF", "CU.SHF"]),
        ("crypto", "加密货币", ["BTC", "ETH"]),
    )
    return jsonify({
        asset_type: {"description": label, "examples": samples}
        for asset_type, label, samples in catalog
    })
# ============================================================
# 错误处理
# ============================================================
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON 404 that lists the known endpoints."""
    known_endpoints = [
        "/",
        "/health",
        "/api/v1/asset-type?code={code}",
        "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "/api/v1/ohlcv/batch",
        "/api/v1/supported-codes",
    ]
    body = {
        "error": "Endpoint not found",
        "available_endpoints": known_endpoints,
    }
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(error):
    """Answer HTTP 500 errors with a JSON body carrying the error text."""
    detail = str(error)
    body = {
        "error": "Internal server error",
        "message": detail,
    }
    return jsonify(body), 500
# ============================================================
# 启动服务
# ============================================================
if __name__ == '__main__':
    # Script entry point: parse CLI flags, report the SSH status, print a
    # short usage banner and start Flask's built-in server.
    import argparse

    cli = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    cli.add_argument('--host', default='0.0.0.0', help='Host to bind (default: 0.0.0.0)')
    cli.add_argument('--port', type=int, default=5000, help='Port to bind (default: 5000)')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = cli.parse_args()

    # Load the SSH settings up front so the status line below and the
    # "/" and "/health" handlers see them.
    ssh_config = get_ssh_config()
    if ssh_config and ssh_config.get('enabled'):
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")
    else:
        print("✗ SSH 隧道未启用仅支持A股数据")

    base_url = f"http://{args.host}:{args.port}"
    banner = (
        "\n🚀 启动 Universal Data Fetcher API Server",
        f" Host: {args.host}",
        f" Port: {args.port}",
        f" Debug: {args.debug}",
        f"\n📖 API 文档: {base_url}/",
        f" 健康检查: {base_url}/health",
        "\n💡 示例请求:",
        f" curl '{base_url}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'",
        "\n",
    )
    for line in banner:
        print(line)

    app.run(host=args.host, port=args.port, debug=args.debug)

127
start_flask_server.sh Executable file
View File

@@ -0,0 +1,127 @@
#!/bin/bash
# Flask API server launcher
# =========================
# ANSI colour codes used for status output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color
echo -e "${GREEN}Universal Data Fetcher API 服务启动脚本${NC}"
echo "=========================================="
# Step 1: a `python` executable must be on PATH
echo -e "\n1. 检查 Python 环境..."
if ! command -v python &> /dev/null; then
    echo -e "${RED}✗ Python 未安装${NC}"
    exit 1
fi
echo -e "${GREEN}✓ Python 已安装: $(python --version)${NC}"
# Step 2: the server imports both flask and flask_cors, so probe both.
# Install through `python -m pip` so the packages land in the interpreter
# that will run the server, and stop if the installation fails.
echo -e "\n2. 检查依赖..."
if ! python -c "import flask, flask_cors" 2>/dev/null; then
    echo -e "${YELLOW}⚠ Flask 未安装,正在安装...${NC}"
    if ! python -m pip install flask flask-cors; then
        echo -e "${RED}✗ 依赖安装失败${NC}"
        exit 1
    fi
fi
echo -e "${GREEN}✓ 依赖检查完成${NC}"
# Step 3: TUSHARE_TOKEN must be set, either already exported or via ./.env
echo -e "\n3. 检查环境变量..."
if [ -z "$TUSHARE_TOKEN" ]; then
    if [ -f ".env" ]; then
        echo -e "${YELLOW}⚠ 从 .env 文件加载环境变量${NC}"
        # Parse KEY=VALUE lines one by one instead of `export $(cat .env | xargs)`,
        # which word-splits values containing spaces and mangles quotes.
        while IFS= read -r line || [ -n "$line" ]; do
            line="${line%$'\r'}"                      # tolerate CRLF files
            line="${line#"${line%%[![:space:]]*}"}"   # trim leading whitespace
            line="${line#export }"                    # allow "export KEY=VALUE"
            case "$line" in
                ''|'#'*) continue ;;                  # blank line or comment
            esac
            [[ "$line" == *=* ]] || continue
            key="${line%%=*}"
            value="${line#*=}"
            key="${key%"${key##*[![:space:]]}"}"      # trim whitespace before '='
            value="${value#"${value%%[![:space:]]*}"}" # trim whitespace after '='
            [[ "$key" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]] || continue
            # Strip one pair of matching surrounding quotes, if present.
            if [[ "$value" == \"*\" || "$value" == \'*\' ]]; then
                value="${value:1:${#value}-2}"
            fi
            export "$key=$value"
        done < .env
    fi
    # Re-check after loading: a .env without the token must not pass silently.
    if [ -z "$TUSHARE_TOKEN" ]; then
        echo -e "${RED}✗ TUSHARE_TOKEN 未设置${NC}"
        echo " 请在 .env 文件中设置 TUSHARE_TOKEN"
        exit 1
    fi
fi
echo -e "${GREEN}✓ 环境变量检查完成${NC}"
# Step 4: check the SSH private key and make sure its mode is 600
echo -e "\n4. 检查 SSH 配置..."
if [ -f "hk_ecs.pem" ]; then
    echo -e "${GREEN}✓ SSH 私钥文件存在 (hk_ecs.pem)${NC}"
    # Try the GNU form first: on GNU stat, `-f` means "file system status"
    # and would print unrelated output, so the BSD/macOS form is the fallback.
    PERM=$(stat -c "%a" hk_ecs.pem 2>/dev/null || stat -f "%Lp" hk_ecs.pem 2>/dev/null)
    KEY_OK=true
    if [ "$PERM" != "600" ]; then
        echo -e "${YELLOW}⚠ 修复 SSH 私钥权限...${NC}"
        if ! chmod 600 hk_ecs.pem; then
            KEY_OK=false
        fi
    fi
    # Report success only when the mode really is 600 now.
    if [ "$KEY_OK" = true ]; then
        echo -e "${GREEN}✓ SSH 私钥权限正确 (600)${NC}"
    else
        echo -e "${RED}✗ SSH 私钥权限修复失败${NC}"
    fi
else
    echo -e "${YELLOW}⚠ SSH 私钥文件不存在,港美股数据获取将受限${NC}"
fi
# Parse command-line options
HOST="0.0.0.0"
PORT="5000"
DEBUG=""
while [[ $# -gt 0 ]]; do
    case $1 in
        --host|--port)
            # Without this guard a trailing "--host"/"--port" makes `shift 2`
            # fail without shifting, and the loop never terminates.
            if [[ $# -lt 2 ]]; then
                echo -e "${RED}选项 $1 需要一个参数${NC}"
                echo "使用 --help 查看帮助"
                exit 1
            fi
            if [[ "$1" == "--host" ]]; then
                HOST="$2"
            else
                PORT="$2"
            fi
            shift 2
            ;;
        --debug)
            DEBUG="--debug"
            shift
            ;;
        --with-ssh)
            # Values already present in the environment win; the literals
            # below are only defaults.
            export SSH_ENABLED=true
            export SSH_HOST="${SSH_HOST:-8.218.167.69}"
            export SSH_PORT="${SSH_PORT:-22}"
            export SSH_USERNAME="${SSH_USERNAME:-root}"
            export SSH_KEY_PATH="${SSH_KEY_PATH:-hk_ecs.pem}"
            export SSH_LOCAL_PORT="${SSH_LOCAL_PORT:-1080}"
            echo -e "${GREEN}✓ SSH 隧道已启用${NC}"
            shift
            ;;
        --help)
            echo ""
            echo "用法: ./start_flask_server.sh [选项]"
            echo ""
            echo "选项:"
            echo " --host HOST 绑定主机 (默认: 0.0.0.0)"
            echo " --port PORT 绑定端口 (默认: 5000)"
            echo " --debug 启用调试模式"
            echo " --with-ssh 启用 SSH 隧道"
            echo " --help 显示帮助"
            echo ""
            echo "示例:"
            echo " ./start_flask_server.sh"
            echo " ./start_flask_server.sh --port 8080"
            echo " ./start_flask_server.sh --with-ssh"
            echo " ./start_flask_server.sh --host 127.0.0.1 --port 5000 --debug --with-ssh"
            exit 0
            ;;
        *)
            echo -e "${RED}未知选项: $1${NC}"
            echo "使用 --help 查看帮助"
            exit 1
            ;;
    esac
done
# Step 5: print the effective settings, then run the server in the foreground
if [ -n "$DEBUG" ]; then
    DEBUG_LABEL="是"
else
    DEBUG_LABEL="否"
fi
if [ "$SSH_ENABLED" = "true" ]; then
    SSH_LABEL="启用"
else
    SSH_LABEL="禁用"
fi
echo -e "\n5. 启动 Flask 服务..."
echo -e " 主机: ${YELLOW}$HOST${NC}"
echo -e " 端口: ${YELLOW}$PORT${NC}"
echo -e " 调试: ${YELLOW}${DEBUG_LABEL}${NC}"
echo -e " SSH: ${YELLOW}${SSH_LABEL}${NC}"
echo ""
echo -e "${GREEN}✓ 服务启动中...${NC}"
echo "=========================================="
echo ""
# Collect the arguments in an array; --debug is appended only when requested.
SERVER_ARGS=(--host "$HOST" --port "$PORT")
if [ -n "$DEBUG" ]; then
    SERVER_ARGS+=("$DEBUG")
fi
python core/datasource/flask_server.py "${SERVER_ARGS[@]}"