diff --git a/db_config.py b/db_config.py new file mode 100644 index 0000000..6c523de --- /dev/null +++ b/db_config.py @@ -0,0 +1,125 @@ +""" +数据库配置和连接工具 +""" + +import psycopg2 +from psycopg2.extras import RealDictCursor +from sqlalchemy import create_engine +import pandas as pd +from loguru import logger +import os +from typing import Optional + + +class DatabaseConfig: + """数据库配置类""" + + def __init__(self, env: str = "online"): + self.host = "host.docker.internal" if env == "online" else "192.168.0.115" + self.port = 5432 + self.database = "etf_db" + self.username = "admin" + self.password = "admin" + + @property + def connection_string(self) -> str: + """获取连接字符串""" + return f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}" + + @property + def psycopg2_params(self) -> dict: + """获取psycopg2连接参数""" + return { + "host": self.host, + "port": self.port, + "database": self.database, + "user": self.username, + "password": self.password, + } + + +class DatabaseManager: + """数据库管理类""" + + def __init__(self, config: DatabaseConfig = None): + self.config = config or DatabaseConfig() + self.engine = None + + def get_engine(self): + """获取SQLAlchemy引擎""" + if self.engine is None: + self.engine = create_engine( + self.config.connection_string, + pool_pre_ping=True, + pool_recycle=300, + echo=False, + ) + return self.engine + + def get_connection(self): + """获取psycopg2连接""" + return psycopg2.connect(**self.config.psycopg2_params) + + def test_connection(self) -> bool: + """测试数据库连接""" + try: + with self.get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + result = cursor.fetchone() + logger.info("数据库连接测试成功") + return True + except Exception as e: + logger.error(f"数据库连接测试失败: {e}") + return False + + def create_table_if_not_exists(self, table_name: str, create_sql: str) -> bool: + """创建表(如果不存在)""" + try: + with self.get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(create_sql) + conn.commit() + logger.info(f"表 {table_name} 创建成功或已存在") + return True + except Exception as e: + logger.error(f"创建表 {table_name} 失败: {e}") + return False + + def insert_dataframe( + self, df: pd.DataFrame, table_name: str, if_exists: str = "append" + ) -> bool: + """将DataFrame插入到数据库表""" + try: + engine = self.get_engine() + df.to_sql( + table_name, + engine, + if_exists=if_exists, + index=False, + method="multi", + chunksize=1000, + ) + logger.info(f"成功插入 {len(df)} 条记录到表 {table_name}") + return True + except Exception as e: + logger.error(f"插入数据到表 {table_name} 失败: {e}") + return False + + def execute_query(self, query: str, params: tuple = None) -> Optional[list]: + """执行查询并返回结果""" + try: + with self.get_connection() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute(query, params) + result = cursor.fetchall() + return result + except Exception as e: + logger.error(f"执行查询失败: {e}") + return None + + def close(self): + """关闭连接""" + if self.engine: + self.engine.dispose() + self.engine = None