123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
from sqlalchemy import create_engine, text
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
from sqlalchemy.pool import QueuePool
|
|
from app.config.settings import settings
|
|
from typing import Generator
|
|
|
|
# Shared SQLAlchemy engine backed by a QueuePool:
#   - pool_size / max_overflow: 10 persistent connections plus up to 20
#     temporary overflow connections,
#   - pool_pre_ping: test a pooled connection before handing it out,
#   - pool_recycle: replace pooled connections older than 3600 seconds,
#   - echo: log emitted SQL whenever settings.config.debug is enabled.
engine = create_engine(
    settings.database_url,
    echo=settings.config.debug,
    poolclass=QueuePool,
    pool_pre_ping=True,
    pool_recycle=3600,
    pool_size=10,
    max_overflow=20,
)
|
|
|
|
# Session factory bound to `engine`. Autocommit is off and autoflush is off,
# so pending changes are not flushed automatically before each query.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
# Declarative base class for ORM models.
# SQLAlchemy 1.4 moved declarative_base() into sqlalchemy.orm and deprecated
# the sqlalchemy.ext.declarative location (calling it warns on 2.0).
# Prefer the new location and fall back to the already-imported legacy
# function on older SQLAlchemy releases.
try:
    from sqlalchemy.orm import declarative_base as _orm_declarative_base
except ImportError:  # SQLAlchemy < 1.4
    _orm_declarative_base = declarative_base

Base = _orm_declarative_base()
|
|
|
|
|
|
def get_db() -> Generator[Session, None, None]:
    """Yield a fresh Session and guarantee it is closed afterwards.

    The ``finally`` clause runs when the generator is exhausted, closed, or
    has an exception thrown into it, so the session is always released.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
|
class DatabaseUtils:
    """Database utility helpers built on the module-level ``engine``.

    Each public method catches ``Exception`` and reports failures as a
    ``{"status": "error", "message": ...}`` dict instead of raising.
    """

    @staticmethod
    def _mask_password(url: str, password) -> str:
        """Return *url* with the database password replaced by ``****``.

        Args:
            url: Connection URL to sanitise before showing it to a caller.
            password: The configured password (a str, or empty/None).

        An empty or ``None`` password returns *url* unchanged. ``str.replace``
        with an empty needle would insert ``****`` between every character,
        and a ``None`` needle raises ``TypeError``. The percent-encoded forms
        of the password are masked too, in case the URL stores it URL-encoded.
        """
        if not password:
            return url
        from urllib.parse import quote, quote_plus

        password = str(password)
        # Replace the longest candidate first, so a shorter form that is a
        # substring of a longer one cannot leave part of the longer one visible.
        candidates = sorted(
            {password, quote(password, safe=""), quote_plus(password)},
            key=len,
            reverse=True,
        )
        for candidate in candidates:
            url = url.replace(candidate, "****")
        return url

    @staticmethod
    def _apply_limit(query: str, limit: int) -> str:
        """Append ``LIMIT <limit>`` to a SELECT/WITH query lacking a LIMIT.

        - The LIMIT check matches the whole word, case-insensitively, so
          identifiers such as ``credit_limit`` do not suppress the cap.
        - Only statements starting with SELECT or WITH (optionally after an
          opening parenthesis) are rewritten. Appending LIMIT to e.g.
          ``SHOW TABLES`` or ``DESCRIBE t`` would produce invalid SQL.
        - Trailing whitespace and semicolons are stripped before appending.
          Otherwise ``SELECT ...; LIMIT n`` would be a syntax error.
        - ``limit`` is coerced with ``int()`` so only an integer is ever
          interpolated into the SQL text. Non-integers raise ``ValueError``.
        """
        import re

        if re.search(r"\bLIMIT\b", query, flags=re.IGNORECASE):
            return query
        if not re.match(r"\s*\(?\s*(?:SELECT|WITH)\b", query, flags=re.IGNORECASE):
            return query
        body = re.sub(r"[\s;]+\Z", "", query)
        return f"{body} LIMIT {int(limit)}"

    @staticmethod
    def test_connection() -> dict:
        """Run ``SELECT 1`` to verify that the database is reachable.

        Returns:
            On success: ``status``, ``message``, ``test_result`` (the selected
            value, or None if no row came back) and ``database_url`` with the
            password masked. On failure: ``status`` and ``message`` with the
            exception text.
        """
        try:
            with engine.connect() as conn:
                row = conn.execute(text("SELECT 1 as test")).fetchone()
            return {
                "status": "success",
                "message": "数据库连接正常",
                "test_result": row[0] if row is not None else None,
                "database_url": DatabaseUtils._mask_password(
                    settings.database_url, settings.config.database.password
                ),
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"数据库连接失败: {str(e)}"
            }

    @staticmethod
    def get_all_tables() -> dict:
        """List table names using ``SHOW TABLES`` (MySQL/MariaDB syntax).

        Returns:
            On success: ``status``, ``tables`` (first column of every result
            row) and ``count``. On failure: ``status`` and ``message``.
        """
        try:
            with engine.connect() as conn:
                result = conn.execute(text("SHOW TABLES"))
                tables = [row[0] for row in result.fetchall()]
            return {
                "status": "success",
                "tables": tables,
                "count": len(tables)
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"获取表名失败: {str(e)}"
            }

    @staticmethod
    def execute_query(query: str, limit: int = 100) -> dict:
        """Execute a raw SQL statement and return its rows as dicts.

        SECURITY NOTE(review): ``query`` is executed verbatim, so this is an
        arbitrary-SQL entry point. Never expose it to untrusted callers, and
        consider a read-only database account for it.

        NOTE(review): no explicit commit is issued. Whether DML persists
        depends on the SQLAlchemy version's autocommit behaviour; confirm
        this is intended.

        Args:
            query: SQL text to run.
            limit: Row cap appended as ``LIMIT`` to SELECT/WITH queries that
                do not already contain a LIMIT clause.

        Returns:
            On success: ``status``, ``columns``, ``data`` (a list of
            column-to-value dicts) and ``count``. Statements that return no
            rows yield empty ``columns`` and ``data``. On failure: ``status``
            and ``message``.
        """
        try:
            query = DatabaseUtils._apply_limit(query, limit)
            with engine.connect() as conn:
                result = conn.execute(text(query))
                # keys()/fetchall() raise on statements that return no rows.
                if result.returns_rows:
                    columns = list(result.keys())
                    data = [dict(zip(columns, row)) for row in result.fetchall()]
                else:
                    columns, data = [], []
            return {
                "status": "success",
                "columns": columns,
                "data": data,
                "count": len(data)
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"查询执行失败: {str(e)}"
            }
|
|
|
|
|
|
# Initialise the database connection
def init_database():
    """Check database connectivity and print the outcome.

    Uses DatabaseUtils.test_connection(). On success, prints the environment
    and host:port/database. On failure, prints the reported message. Any
    ``Exception`` raised along the way is printed rather than propagated.
    """
    try:
        outcome = DatabaseUtils.test_connection()
        if outcome["status"] != "success":
            print(f"✗ 数据库连接失败: {outcome['message']}")
            return
        print(f"✓ 数据库连接成功 - 环境: {settings.environment}")
        db_cfg = settings.config.database
        print(f"✓ 数据库: {db_cfg.host}:{db_cfg.port}/{db_cfg.database}")
    except Exception as e:
        print(f"✗ 数据库初始化错误: {str(e)}")
|
|
|
|
|
|
def close_database():
    """Dispose of the engine's connection pool, then print a notice."""
    # dispose() closes the pooled connections held by `engine`.
    engine.dispose()
    closed_notice = "数据库连接已关闭"
    print(closed_notice)