wms-py/app/utils/database.py
2025-06-04 10:39:32 +08:00

123 lines
3.8 KiB
Python

import re
from typing import Generator
from urllib.parse import quote, quote_plus

from sqlalchemy import create_engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool

from app.config.settings import settings
# Shared database engine backed by a QueuePool connection pool.
engine = create_engine(
    settings.database_url,
    echo=settings.config.debug,  # log every SQL statement when debug mode is on
    poolclass=QueuePool,
    pool_size=10,  # connections kept open in the pool
    max_overflow=20,  # extra connections allowed beyond pool_size under load
    pool_pre_ping=True,  # verify a pooled connection is alive before handing it out
    pool_recycle=3600,  # replace connections older than 3600 seconds on checkout
)

# Factory for ORM sessions bound to the engine above. autoflush is off, so
# pending changes are only sent on an explicit flush()/commit().
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class for ORM model definitions.
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
    """Yield one new ORM session, closing it once the consumer is done.

    Written as a generator so it can be used as a setup/teardown dependency:
    the ``finally`` clause closes the session whether the consumer finished
    normally or raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always return the underlying connection to the pool.
        session.close()
class DatabaseUtils:
    """Ad-hoc database helpers: connectivity check, table listing, raw queries.

    Every public method catches all exceptions and reports the outcome as a
    dict whose ``"status"`` key is ``"success"`` or ``"error"`` instead of
    raising.
    """

    # LIMIT as a standalone keyword, so identifiers such as ``credit_limit`` or
    # ``limits`` no longer suppress the automatic row cap. Known limitation: a
    # LIMIT inside a string literal or SQL comment still counts as present.
    _LIMIT_PATTERN = re.compile(r"\bLIMIT\b", re.IGNORECASE)

    @staticmethod
    def _mask_password(url: str, password) -> str:
        """Return *url* with the password replaced by ``****``.

        An empty or ``None`` password leaves *url* unchanged; ``str.replace``
        with an empty needle would otherwise insert the mask between every
        character. Non-string passwords (e.g. numeric values from config) are
        converted with ``str()``. Percent-encoded forms are masked first, in
        case the URL was built with ``quote``/``quote_plus``, so the raw
        secret cannot leak through its encoded spelling.
        """
        if password is None:
            return url
        secret = str(password)
        if not secret:
            return url
        for form in (quote(secret, safe=""), quote_plus(secret), secret):
            url = url.replace(form, "****")
        return url

    @staticmethod
    def _apply_limit(query: str, limit: int) -> str:
        """Return *query* with ``LIMIT <limit>`` appended unless it already has one.

        Trailing whitespace and semicolons are stripped before appending,
        because ``SELECT ...; LIMIT n`` is a syntax error. *limit* is coerced
        with ``int()``, so a non-integer value raises ``ValueError``/``TypeError``
        instead of being spliced into the SQL text.
        """
        if DatabaseUtils._LIMIT_PATTERN.search(query):
            return query
        statement = query.rstrip(" \t\r\n\f\v;")
        return f"{statement} LIMIT {int(limit)}"

    @staticmethod
    def test_connection() -> dict:
        """Check connectivity by running ``SELECT 1``.

        Returns:
            On success: ``status``, ``message``, ``test_result`` (the selected
            value, or ``None`` if no row came back) and ``database_url`` with
            the password masked. On failure: ``status`` and ``message``.
        """
        try:
            with engine.connect() as conn:
                row = conn.execute(text("SELECT 1 as test")).fetchone()
            return {
                "status": "success",
                "message": "数据库连接正常",
                "test_result": row[0] if row else None,
                "database_url": DatabaseUtils._mask_password(
                    settings.database_url, settings.config.database.password
                ),
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"数据库连接失败: {str(e)}"
            }

    @staticmethod
    def get_all_tables() -> dict:
        """List table names using ``SHOW TABLES`` (MySQL/MariaDB syntax).

        Returns:
            On success: ``status``, ``tables`` (list of names) and ``count``.
            On failure: ``status`` and ``message``.
        """
        try:
            with engine.connect() as conn:
                result = conn.execute(text("SHOW TABLES"))
                tables = [row[0] for row in result.fetchall()]
                return {
                    "status": "success",
                    "tables": tables,
                    "count": len(tables)
                }
        except Exception as e:
            return {
                "status": "error",
                "message": f"获取表名失败: {str(e)}"
            }

    @staticmethod
    def execute_query(query: str, limit: int = 100) -> dict:
        """Run a raw SQL statement and return its rows as dicts.

        NOTE(review): *query* is executed verbatim. If it can reach this method
        from an untrusted caller, that is arbitrary SQL execution. Restrict
        access to this helper or whitelist read-only statements.

        Args:
            query: SQL text. ``LIMIT <limit>`` is appended when the statement
                has no LIMIT keyword of its own.
            limit: Row cap to append; must be convertible with ``int()``.

        Returns:
            On success: ``status``, ``columns``, ``data`` (one column-name ->
            value dict per row) and ``count``. On failure: ``status`` and
            ``message``.
        """
        try:
            statement = DatabaseUtils._apply_limit(query, limit)
            with engine.connect() as conn:
                result = conn.execute(text(statement))
                columns = list(result.keys())
                data = [dict(zip(columns, row)) for row in result.fetchall()]
                return {
                    "status": "success",
                    "columns": columns,
                    "data": data,
                    "count": len(data)
                }
        except Exception as e:
            return {
                "status": "error",
                "message": f"查询执行失败: {str(e)}"
            }
# Startup connectivity check for the database.
def init_database():
    """Probe the database once and print the outcome to stdout.

    Never raises: a failed probe prints the failure message, and any other
    exception (e.g. while reading settings) is printed as an initialization
    error.
    """
    try:
        outcome = DatabaseUtils.test_connection()
        if outcome["status"] != "success":
            print(f"✗ 数据库连接失败: {outcome['message']}")
            return
        print(f"✓ 数据库连接成功 - 环境: {settings.environment}")
        db_cfg = settings.config.database
        print(f"✓ 数据库: {db_cfg.host}:{db_cfg.port}/{db_cfg.database}")
    except Exception as e:
        print(f"✗ 数据库初始化错误: {str(e)}")
def close_database():
    """Dispose of the shared engine's connection pool and report it on stdout."""
    # dispose() releases the pool; the engine creates a new one if used again.
    engine.dispose()
    print("数据库连接已关闭")