修复多个问题

This commit is contained in:
ZhangYonghao
2026-03-21 20:32:19 +08:00
parent f2c371b87d
commit 10d463a55f
12 changed files with 1021 additions and 275 deletions

View File

@@ -1,3 +1,4 @@
import json
import os
from pathlib import Path
@@ -7,16 +8,74 @@ backend_dir = Path(__file__).resolve().parent.parent
load_dotenv(backend_dir / ".env")
def _get_int_env(name: str, default: int) -> int:
value = os.getenv(name, "").strip()
if not value:
return default
try:
return int(value)
except ValueError:
return default
def _get_bool_env(name: str, default: bool = False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def _get_list_env(name: str, default: list[str]) -> list[str]:
value = os.getenv(name, "").strip()
if not value:
return default
try:
parsed = json.loads(value)
except json.JSONDecodeError:
parsed = None
if isinstance(parsed, list):
return [str(item).strip() for item in parsed if str(item).strip()]
return [item.strip() for item in value.split(",") if item.strip()]
def _get_path_env(name: str, default: Path) -> Path:
value = os.getenv(name, "").strip()
if not value:
return default.resolve()
path = Path(value)
if not path.is_absolute():
path = backend_dir / path
return path.resolve()
class Settings:
    """Application configuration, evaluated once when the class body runs.

    Every value is read from an environment variable with a built-in fallback.
    """

    app_name = os.getenv("APP_NAME", "HTML Knowledge API")
    # Trailing slashes are removed so URL builders can append "/html/..."
    # directly; a prefix that becomes empty falls back to "/api".
    api_prefix = os.getenv("API_PREFIX", "/api").rstrip("/") or "/api"

    backend_dir = backend_dir
    data_dir = backend_dir / "data"
    database_path = data_dir / "html_generator.db"
    database_url = f"sqlite:///{database_path.as_posix()}"

    # Older settings, retained so any existing reference keeps working.
    frontend_base_url = os.environ.get("FRONTEND_BASE_URL", "http://localhost:3000")
    static_dir = backend_dir / "../frontend/public/static"

    allowed_origins = _get_list_env("ALLOWED_ORIGINS", ["*"])
    public_base_url = os.getenv("PUBLIC_BASE_URL", "http://localhost:8000").rstrip("/")
    html_storage_dir = _get_path_env(
        "HTML_STORAGE_DIR",
        data_dir / "generated_html",
    )

    # Retention is at least one day, and the maximum never drops below the
    # default.
    default_retention_days = max(1, _get_int_env("DEFAULT_RETENTION_DAYS", 7))
    max_retention_days = max(
        default_retention_days,
        _get_int_env("MAX_RETENTION_DAYS", 30),
    )
    # Upper bound on html_content, with a floor of 1024.
    max_html_length = max(1024, _get_int_env("MAX_HTML_LENGTH", 200_000))

    # An empty API key disables the X-API-Key check.
    api_key = os.getenv("API_KEY", "").strip()
    allow_unsafe_html = _get_bool_env("ALLOW_UNSAFE_HTML", False)


settings = Settings()

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from sqlalchemy import create_engine
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import declarative_base, sessionmaker
from app.config import settings
@@ -15,9 +15,39 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def ensure_database_schema() -> None:
    """Add columns and the index that an existing ``html_files`` table may lack.

    Does nothing when the table does not exist.  Otherwise each missing
    column is added with ``ALTER TABLE`` and the ``expires_at`` index is
    created if absent, all inside a single ``engine.begin()`` block.
    """
    schema = inspect(engine)
    if "html_files" not in schema.get_table_names():
        return

    present = {info["name"] for info in schema.get_columns("html_files")}
    migrations = (
        ("title", "ALTER TABLE html_files ADD COLUMN title VARCHAR(120)"),
        ("source", "ALTER TABLE html_files ADD COLUMN source VARCHAR(80)"),
        ("request_id", "ALTER TABLE html_files ADD COLUMN request_id VARCHAR(120)"),
        ("size_bytes", "ALTER TABLE html_files ADD COLUMN size_bytes INTEGER"),
        ("expires_at", "ALTER TABLE html_files ADD COLUMN expires_at DATETIME"),
    )
    pending = [sql for column, sql in migrations if column not in present]

    with engine.begin() as connection:
        for sql in pending:
            connection.execute(text(sql))
        # The index statement runs after any column additions, so expires_at
        # exists by the time it is indexed.
        connection.execute(
            text(
                "CREATE INDEX IF NOT EXISTS ix_html_files_expires_at "
                "ON html_files (expires_at)"
            )
        )
def get_db() -> Generator:
    """Yield a database session and close it when the consumer is finished.

    The ``finally`` clause closes the session even when the consumer raises.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

View File

@@ -4,9 +4,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.database import Base, engine, SessionLocal
from app.models import HTMLFile
from app.database import Base, SessionLocal, engine, ensure_database_schema
from app.routers import html
from app.routers.html import cleanup_expired_files
logging.basicConfig(
level=logging.INFO,
@@ -14,18 +14,19 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
settings.html_storage_dir.mkdir(parents=True, exist_ok=True)
Base.metadata.create_all(bind=engine)
ensure_database_schema()
# 删除过期记录
db = SessionLocal()
try:
deleted_count = HTMLFile.delete_expired_records(db)
if deleted_count > 0:
logger.info(f"Deleted {deleted_count} expired HTML file records")
finally:
db.close()
app = FastAPI(title=settings.app_name)
app = FastAPI(
title=settings.app_name,
version="2.0.0",
description=(
"Store agent-generated educational HTML pages and return a direct access URL. "
"The generated OpenAPI document can be imported directly into Tencent Cloud "
"Agent plugins."
),
)
app.add_middleware(
CORSMiddleware,
@@ -38,6 +39,20 @@ app.add_middleware(
app.include_router(html.router, prefix=settings.api_prefix)
@app.on_event("startup")
def cleanup_on_startup() -> None:
    """Remove expired HTML files once when the application starts.

    Uses its own session and closes it whether or not the cleanup succeeds.
    """
    db = SessionLocal()
    try:
        deleted_count = cleanup_expired_files(db)
        if deleted_count > 0:
            logger.info("Deleted %s expired HTML files during startup", deleted_count)
    finally:
        db.close()
@app.get("/", summary="Health check")
def health_check() -> dict[str, str]:
    # Liveness probe that also points at the generated OpenAPI document.
    # A comment is used instead of a docstring so the OpenAPI description of
    # this route stays unchanged.
    return {
        "message": "HTML Knowledge API is running",
        "openapi_url": "/openapi.json",
    }

View File

@@ -1,7 +1,7 @@
from datetime import datetime, timedelta
from sqlalchemy import DateTime, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, Session
from sqlalchemy import DateTime, Integer, String, and_, or_
from sqlalchemy.orm import Mapped, Session, mapped_column
from app.database import Base
@@ -11,20 +11,57 @@ class HTMLFile(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
unique_id: Mapped[str] = mapped_column(
String(32), unique=True, index=True, nullable=False
String(32),
unique=True,
index=True,
nullable=False,
)
filename: Mapped[str] = mapped_column(String(255), nullable=False)
title: Mapped[str | None] = mapped_column(String(120), nullable=True)
source: Mapped[str | None] = mapped_column(String(80), nullable=True)
request_id: Mapped[str | None] = mapped_column(String(120), nullable=True)
size_bytes: Mapped[int | None] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, nullable=False
DateTime,
default=datetime.utcnow,
nullable=False,
)
expires_at: Mapped[datetime | None] = mapped_column(
DateTime,
nullable=True,
index=True,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
nullable=False,
)
@classmethod
def delete_expired_records(cls, db: Session, days: int = 5) -> int:
    """Delete records created more than ``days`` days ago and return the count.

    Only database rows are removed here; no file operations are performed.
    """
    cutoff_date = datetime.utcnow() - timedelta(days=days)
    deleted = db.query(cls).filter(cls.created_at < cutoff_date).delete()
    db.commit()
    return deleted

@classmethod
def list_expired_records(
    cls,
    db: Session,
    default_retention_days: int,
) -> list["HTMLFile"]:
    """Return every record that is past its expiry.

    A record counts as expired when its ``expires_at`` is earlier than now,
    or when it has no ``expires_at`` and its ``created_at`` is more than
    ``default_retention_days`` days old.  Timestamps are compared against
    ``datetime.utcnow()``.
    """
    now = datetime.utcnow()
    fallback_cutoff = now - timedelta(days=default_retention_days)
    return (
        db.query(cls)
        .filter(
            or_(
                cls.expires_at < now,
                and_(
                    cls.expires_at.is_(None),
                    cls.created_at < fallback_cutoff,
                ),
            )
        )
        .all()
    )
def resolved_expires_at(self, default_retention_days: int) -> datetime:
    """Return the effective expiry time of this record.

    Uses the stored ``expires_at`` when present; otherwise ``created_at``
    plus ``default_retention_days`` days.
    """
    explicit = self.expires_at
    if explicit is None:
        return self.created_at + timedelta(days=default_retention_days)
    return explicit

View File

@@ -1,8 +1,13 @@
import os
import secrets
import logging
import re
import secrets
import tempfile
from datetime import datetime, timedelta
from html import escape
from pathlib import Path
from fastapi import APIRouter, HTTPException, status, Depends
from fastapi import APIRouter, Depends, Header, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.config import settings
@@ -13,74 +18,308 @@ from app.schemas import HTMLGenerateRequest, HTMLGenerateResponse
router = APIRouter(prefix="/html", tags=["html"])
logger = logging.getLogger(__name__)
# Deny-list used by validate_html_safety: each entry pairs a compiled,
# case-insensitive regex with the reason reported when it matches.
# NOTE(review): regex matching on raw HTML is best-effort; presumably the
# Content-Security-Policy header on the content endpoint is the second line
# of defence -- confirm.
DANGEROUS_HTML_PATTERNS = (
    (re.compile(r"<\s*script\b", re.IGNORECASE), "script tags are not allowed"),
    (re.compile(r"<\s*iframe\b", re.IGNORECASE), "iframe tags are not allowed"),
    (re.compile(r"<\s*(?:object|embed|base)\b", re.IGNORECASE), "embedded active content is not allowed"),
    (re.compile(r"<\s*form\b", re.IGNORECASE), "form tags are not allowed"),
    (re.compile(r"<\s*link\b", re.IGNORECASE), "external stylesheet or import tags are not allowed"),
    (
        # <meta http-equiv="refresh" ...>
        re.compile(r"<\s*meta\b[^>]*http-equiv\s*=\s*['\"]?\s*refresh", re.IGNORECASE),
        "automatic refresh or redirect is not allowed",
    ),
    # Attributes such as onclick= / onerror= that follow a whitespace character.
    (re.compile(r"\son[a-z]+\s*=", re.IGNORECASE), "inline event handlers are not allowed"),
    (re.compile(r"javascript\s*:", re.IGNORECASE), "javascript URLs are not allowed"),
)
def generate_unique_id() -> str:
return secrets.token_urlsafe(16)
# Policy string sent as the Content-Security-Policy header with every served
# page.  Directives are joined with "; ".  Everything is denied by default;
# only images, inline styles, fonts and https media are allowed.
CONTENT_SECURITY_POLICY = "; ".join(
    [
        "default-src 'none'",
        "img-src 'self' data: https:",
        "style-src 'unsafe-inline'",
        "font-src 'self' data: https:",
        "media-src https:",
        "script-src 'none'",
        "connect-src 'none'",
        "object-src 'none'",
        "base-uri 'none'",
        "form-action 'none'",
        "frame-ancestors 'none'",
    ]
)
@router.post("/generate", response_model=HTMLGenerateResponse, status_code=status.HTTP_201_CREATED)
def generate_html(request: HTMLGenerateRequest, db: Session = Depends(get_db)):
def require_api_key(x_api_key: str | None = Header(default=None, alias="X-API-Key")) -> None:
    """Reject the request unless it carries the configured API key.

    The check is skipped when ``settings.api_key`` is empty.

    Raises:
        HTTPException: 401 when the ``X-API-Key`` header is missing or does
            not match the configured key.
    """
    expected = settings.api_key
    if not expected:
        return

    supplied = x_api_key or ""
    # Constant-time comparison avoids leaking the key through response timing.
    # Both sides are encoded to bytes because compare_digest rejects
    # non-ASCII str arguments.
    if not secrets.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )
def build_content_url(unique_id: str) -> str:
    """Return the public URL that serves the stored page for ``unique_id``."""
    api_root = f"{settings.public_base_url}{settings.api_prefix}"
    return f"{api_root}/html/{unique_id}/content"
def build_query_url(unique_id: str) -> str:
    """Return the public URL of the metadata endpoint for ``unique_id``."""
    api_root = f"{settings.public_base_url}{settings.api_prefix}"
    return f"{api_root}/html/{unique_id}"
def generate_unique_id(db: Session) -> str:
    """Return a random id that no existing ``HTMLFile`` row uses.

    Candidates come from ``secrets.token_urlsafe(12)`` with ``-`` and ``_``
    removed.  Up to ten candidates are tried.

    Raises:
        RuntimeError: when all ten candidates are already taken.
    """
    attempts_left = 10
    while attempts_left > 0:
        attempts_left -= 1
        token = secrets.token_urlsafe(12)
        candidate = token.replace("-", "").replace("_", "")
        existing = (
            db.query(HTMLFile.id)
            .filter(HTMLFile.unique_id == candidate)
            .first()
        )
        if not existing:
            return candidate
    raise RuntimeError("Unable to generate a unique id")
def build_html_document(raw_html: str, title: str | None) -> str:
normalized_html = raw_html.strip()
if re.search(r"<!doctype\s+html|<html\b", normalized_html, re.IGNORECASE):
return normalized_html
escaped_title = escape(title or "知识点讲解")
return f"""<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>{escaped_title}</title>
<style>
:root {{
color-scheme: light;
}}
* {{
box-sizing: border-box;
}}
body {{
margin: 0;
background: #f5f7fb;
color: #18202a;
font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
line-height: 1.75;
}}
main {{
max-width: 960px;
margin: 0 auto;
padding: 32px 20px 48px;
}}
</style>
</head>
<body>
<main>
{normalized_html}
</main>
</body>
</html>
"""
def validate_html_safety(html_content: str) -> None:
    """Reject HTML that matches any entry in ``DANGEROUS_HTML_PATTERNS``.

    The check is skipped when ``settings.allow_unsafe_html`` is set.

    Raises:
        HTTPException: 400 naming the first pattern that matched.
    """
    if settings.allow_unsafe_html:
        return

    # Patterns are tried in declaration order; only the first hit is reported.
    violation = next(
        (
            reason
            for matcher, reason in DANGEROUS_HTML_PATTERNS
            if matcher.search(html_content)
        ),
        None,
    )
    if violation is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsafe HTML rejected: {violation}",
        )
def write_html_file(target_path: Path, html_content: str) -> None:
    """Write ``html_content`` to ``target_path`` through a temporary file.

    The content is first written as UTF-8 to a ``.tmp`` file in the target
    directory (created if missing), which is then moved over ``target_path``
    with ``Path.replace``.  If anything fails before the move, the temporary
    file is removed and the exception propagates.
    """
    target_path.parent.mkdir(parents=True, exist_ok=True)
    temporary_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            delete=False,
            dir=target_path.parent,
            suffix=".tmp",
        ) as temporary_file:
            # Record the path before writing so a failed write (e.g. an
            # encoding error) still gets cleaned up below.
            temporary_path = Path(temporary_file.name)
            temporary_file.write(html_content)
        temporary_path.replace(target_path)
        # The file now lives at target_path; nothing is left to clean up.
        temporary_path = None
    finally:
        if temporary_path is not None:
            temporary_path.unlink(missing_ok=True)
def delete_stored_file(filename: str) -> None:
    """Remove ``filename`` from the HTML storage directory if it exists."""
    stored = settings.html_storage_dir / filename
    if not stored.exists():
        return
    stored.unlink(missing_ok=True)
def cleanup_expired_files(db: Session) -> int:
    """Delete every expired record together with its stored file.

    Returns the number of records removed.  A single commit is issued after
    all deletions; nothing is committed when no record has expired.
    """
    stale = HTMLFile.list_expired_records(
        db,
        settings.default_retention_days,
    )
    removed = len(stale)
    if removed == 0:
        return 0

    for entry in stale:
        delete_stored_file(entry.filename)
        db.delete(entry)
    db.commit()
    return removed
def get_record_or_404(unique_id: str, db: Session) -> HTMLFile:
    """Load the live record for ``unique_id``.

    An expired record is deleted, together with its stored file, before the
    404 is raised.

    Raises:
        HTTPException: 404 when the record does not exist or has expired.
    """
    record = db.query(HTMLFile).filter(HTMLFile.unique_id == unique_id).first()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="HTML file not found",
        )

    deadline = record.resolved_expires_at(settings.default_retention_days)
    if deadline > datetime.utcnow():
        return record

    # Past its deadline: purge it now instead of waiting for the next cleanup.
    delete_stored_file(record.filename)
    db.delete(record)
    db.commit()
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="HTML file has expired",
    )
def build_response(html_file: HTMLFile) -> HTMLGenerateResponse:
    """Convert a stored record into the API response model.

    ``size_bytes`` falls back to 0 when unset, and ``expires_at`` is the
    record's resolved expiry under the default retention period.
    """
    identifier = html_file.unique_id
    payload = {
        "message": "HTML file generated successfully",
        "unique_id": identifier,
        "url": build_content_url(identifier),
        "query_url": build_query_url(identifier),
        "title": html_file.title,
        "source": html_file.source,
        "request_id": html_file.request_id,
        "size_bytes": html_file.size_bytes or 0,
        "created_at": html_file.created_at,
        "expires_at": html_file.resolved_expires_at(settings.default_retention_days),
    }
    return HTMLGenerateResponse(**payload)
@router.post(
    "/generate",
    response_model=HTMLGenerateResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Generate and publish an HTML explanation page",
    description=(
        "Accepts agent-generated HTML, stores it with a unique random filename, "
        "and returns a direct access URL."
    ),
)
def generate_html(
    request: HTMLGenerateRequest,
    _: None = Depends(require_api_key),
    db: Session = Depends(get_db),
) -> HTMLGenerateResponse:
    """Validate, store and register an HTML page, then return its URLs.

    Expired files are purged first.  The submitted HTML is checked against
    the deny-list, wrapped into a full document when it is only a fragment,
    written to the storage directory and recorded in the database.

    Raises:
        HTTPException: 400 from the safety check (re-raised unchanged), or
            500 for any other failure, after the database transaction has
            been rolled back and the written file removed.
    """
    html_path: Path | None = None
    try:
        deleted_count = cleanup_expired_files(db)
        if deleted_count > 0:
            logger.info("Deleted %s expired HTML files", deleted_count)

        validate_html_safety(request.html_content)

        unique_id = generate_unique_id(db)
        html_filename = f"{unique_id}.html"
        html_path = settings.html_storage_dir / html_filename
        html_document = build_html_document(request.html_content, request.title)
        expires_at = datetime.utcnow() + timedelta(
            days=request.ttl_days or settings.default_retention_days
        )
        # Size of the stored document, which may be larger than the request
        # body once a fragment has been wrapped.
        size_bytes = len(html_document.encode("utf-8"))

        write_html_file(html_path, html_document)

        html_file = HTMLFile(
            unique_id=unique_id,
            filename=html_filename,
            title=request.title,
            source=request.source,
            request_id=request.request_id,
            size_bytes=size_bytes,
            expires_at=expires_at,
        )
        db.add(html_file)
        db.commit()
        db.refresh(html_file)

        return build_response(html_file)
    except HTTPException:
        # Deliberate client errors pass through untouched.
        raise
    except Exception as exc:
        logger.exception("Failed to generate HTML file")
        db.rollback()
        # Do not leave a file on disk without a database row.
        if html_path and html_path.exists():
            html_path.unlink(missing_ok=True)
        # NOTE(review): the detail echoes the exception text to the client;
        # consider a generic message if that could expose internals.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate HTML file: {exc}",
        ) from exc
@router.get(
    "/{unique_id}",
    response_model=HTMLGenerateResponse,
    summary="Query metadata for a generated HTML file",
)
def get_html_file(unique_id: str, db: Session = Depends(get_db)) -> HTMLGenerateResponse:
    # Returns the metadata of a stored page.  Responds 404 when the record is
    # unknown or expired (via get_record_or_404), or when its file is missing
    # from storage.  A comment is used instead of a docstring so the OpenAPI
    # description of this route stays unchanged.
    html_file = get_record_or_404(unique_id, db)
    file_path = settings.html_storage_dir / html_file.filename
    if not file_path.exists():
        # The row outlived its file: drop the orphan row before reporting 404.
        db.delete(html_file)
        db.commit()
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="HTML file has been removed from storage",
        )

    return build_response(html_file)
@router.get(
    "/{unique_id}/content",
    summary="Serve the generated HTML content",
    response_description="The generated HTML page",
)
def get_html_content(unique_id: str, db: Session = Depends(get_db)) -> FileResponse:
    # Streams the stored page with restrictive security headers.  Responds
    # 404 when the record is unknown or expired, or when its file is missing
    # from storage (the orphan row is deleted first).
    record = get_record_or_404(unique_id, db)
    stored_file = settings.html_storage_dir / record.filename

    if stored_file.exists():
        response_headers = {
            "Content-Security-Policy": CONTENT_SECURITY_POLICY,
            "X-Content-Type-Options": "nosniff",
            "Referrer-Policy": "no-referrer",
            "Cache-Control": "public, max-age=300",
        }
        return FileResponse(
            path=stored_file,
            media_type="text/html",
            headers=response_headers,
        )

    db.delete(record)
    db.commit()
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="HTML file has been removed from storage",
    )

View File

@@ -1,23 +1,95 @@
from datetime import datetime
from pydantic import BaseModel
from pydantic import BaseModel, Field, root_validator, validator
from app.config import settings
class HTMLGenerateRequest(BaseModel):
    # Request body for generating an HTML page.  Unknown fields are ignored,
    # and the aliases "html" / "content" / "expire_days" are accepted for
    # "html_content" / "ttl_days".  Comments are used instead of a class
    # docstring so the generated schema description stays unchanged.
    html_content: str = Field(
        ...,
        description="Required HTML content or HTML fragment.",
    )
    title: str | None = Field(
        default=None,
        max_length=120,
        description="Optional page title shown in the generated HTML document.",
    )
    source: str | None = Field(
        default=None,
        max_length=80,
        description="Optional source identifier such as a Tencent agent name.",
    )
    request_id: str | None = Field(
        default=None,
        max_length=120,
        description="Optional trace id used for debugging and log correlation.",
    )
    ttl_days: int | None = Field(
        default=None,
        ge=1,
        description="Optional retention days for the file.",
    )

    @root_validator(pre=True)
    def normalize_aliases(cls, values: dict) -> dict:
        """Copy alias keys onto their canonical names before validation.

        A canonical key that is already present always wins over its aliases.
        """
        alias_map = {
            "html": "html_content",
            "content": "html_content",
            "expire_days": "ttl_days",
        }
        for alias, target in alias_map.items():
            if target not in values and alias in values:
                values[target] = values[alias]
        return values

    @validator("html_content")
    def validate_html_content(cls, value: str) -> str:
        """Strip the content and enforce the non-empty and size limits."""
        content = value.strip()
        if not content:
            raise ValueError("html_content cannot be empty")
        # The limit is applied to the UTF-8 encoded size, not the character
        # count.
        if len(content.encode("utf-8")) > settings.max_html_length:
            raise ValueError(
                f"html_content exceeds the limit of {settings.max_html_length} bytes"
            )
        return content

    @validator("title", "source", "request_id")
    def normalize_optional_text(cls, value: str | None) -> str | None:
        """Strip optional text fields, mapping blank strings to ``None``."""
        if value is None:
            return None
        normalized = value.strip()
        return normalized or None

    @validator("ttl_days")
    def validate_ttl_days(cls, value: int | None) -> int | None:
        """Reject a retention longer than ``settings.max_retention_days``."""
        if value is None:
            return None
        if value > settings.max_retention_days:
            raise ValueError(
                f"ttl_days cannot be greater than {settings.max_retention_days}"
            )
        return value

    class Config:
        extra = "ignore"
class HTMLGenerateResponse(BaseModel):
    # Response for both the generate and the metadata endpoints.  The fields
    # mirror the keyword arguments that build_response passes.
    message: str
    unique_id: str
    url: str = Field(description="Direct URL that serves the generated HTML content.")
    query_url: str = Field(description="Metadata URL for querying the generated record.")
    title: str | None = None
    source: str | None = None
    request_id: str | None = None
    size_bytes: int
    created_at: datetime
    expires_at: datetime


class HTMLFileResponse(BaseModel):
    # Earlier record-shaped response model, retained so existing imports keep
    # working.
    id: int
    unique_id: str
    filename: str
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True