diff --git a/backend/app/main.py b/backend/app/main.py index 83d8028..508696e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -4,7 +4,8 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app.config import settings -from app.database import Base, engine +from app.database import Base, engine, SessionLocal +from app.models import Mindmap from app.routers import chat, mindmaps logging.basicConfig( @@ -15,6 +16,15 @@ logger = logging.getLogger(__name__) Base.metadata.create_all(bind=engine) +# 删除过期记录 +db = SessionLocal() +try: + deleted_count = Mindmap.delete_expired_records(db) + if deleted_count > 0: + logger.info(f"Deleted {deleted_count} expired mindmap records") +finally: + db.close() + app = FastAPI(title=settings.app_name) app.add_middleware( diff --git a/backend/app/models.py b/backend/app/models.py index 38828e6..5a9465d 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,7 +1,7 @@ -from datetime import datetime +from datetime import datetime, timedelta from sqlalchemy import DateTime, Integer, String, Text -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, Session from app.database import Base @@ -22,3 +22,11 @@ class Mindmap(Base): updated_at: Mapped[datetime] = mapped_column( DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False ) + + @classmethod + def delete_expired_records(cls, db: Session, days: int = 5) -> int: + """删除超过指定天数的记录""" + cutoff_date = datetime.utcnow() - timedelta(days=days) + deleted = db.query(cls).filter(cls.created_at < cutoff_date).delete() + db.commit() + return deleted diff --git a/backend/app/routers/mindmaps.py b/backend/app/routers/mindmaps.py index 71880c3..0071bd2 100644 --- a/backend/app/routers/mindmaps.py +++ b/backend/app/routers/mindmaps.py @@ -1,4 +1,5 @@ import json +import logging import secrets from fastapi import APIRouter, Depends, HTTPException, status @@ -9,6 +10,8 @@ from app.database import get_db from app.models import Mindmap from app.schemas import MindmapCreateRequest, MindmapNode, MindmapResponse +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/mindmaps", tags=["mindmaps"]) @@ -44,6 +47,11 @@ def create_mindmap( payload: MindmapCreateRequest, db: Session = Depends(get_db), ) -> MindmapResponse: + # 先删除过期记录 + deleted_count = Mindmap.delete_expired_records(db) + if deleted_count > 0: + logger.info(f"Deleted {deleted_count} expired mindmap records") + title = extract_title_from_json(payload.mindmap_json) raw_json = json.dumps(payload.mindmap_json, ensure_ascii=False) unique_id = generate_unique_id()