126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
import json
|
|
import logging
|
|
import uuid
|
|
from typing import AsyncGenerator
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app.config import settings
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Every endpoint in this module is mounted under ``/chat`` and tagged "chat".
router = APIRouter(tags=["chat"], prefix="/chat")

# Upstream Tencent SSE chat endpoint that ``chat`` POSTs to.
TENCENT_SSE_URL: str = "https://wss.lke.cloud.tencent.com/v1/qbot/chat/sse"
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /chat``.

    Each field is copied verbatim into the JSON body that ``chat`` sends to
    the upstream Tencent SSE endpoint.
    """

    # Conversation identifier; forwarded upstream as ``session_id``.
    session_id: str

    # User message text; forwarded upstream as ``content``.
    content: str

    # Visitor identifier; forwarded upstream as ``visitor_biz_id``.
    # NOTE(review): every client that omits this field shares the same
    # "default_visitor" identity upstream — confirm that is intended.
    visitor_biz_id: str = "default_visitor"
|
|
|
|
|
|
async def forward_events(
    response: httpx.Response, request_id: str
) -> AsyncGenerator[bytes, None]:
    """Relay the upstream SSE stream line by line without altering content.

    Each line produced by ``response.aiter_lines()`` is re-emitted with a
    single ``"\\n"`` terminator. Empty lines are forwarded as a bare
    ``b"\\n"`` so the downstream client still sees the blank line that ends
    each SSE event.

    Args:
        response: An open streaming httpx response from the upstream
            endpoint.
        request_id: Short correlation ID used to prefix log messages.

    Yields:
        UTF-8 encoded lines, each terminated by ``b"\\n"``.
    """
    async for line in response.aiter_lines():
        # Remove only line terminators (present in some httpx versions).
        # The previous ``strip()`` also removed leading/trailing spaces.
        # Those spaces are significant in SSE field values, and removing them
        # turned whitespace-only lines into event-dispatching blank lines.
        text = line.rstrip("\r\n")
        if not text:
            yield b"\n"
            continue

        # This fires for every streamed line and includes conversation
        # content, so it is logged at DEBUG rather than INFO.
        logger.debug("[%s] Forwarding: %s", request_id, text[:200])
        yield (text + "\n").encode("utf-8")
|
|
|
|
|
|
def build_error_event(message: str, request_id: str, code: int = 500) -> bytes:
    """Encode an error as one SSE frame with event name ``error``.

    The frame's ``data`` line carries a JSON object of the form
    ``{"type": "error", "request_id": ..., "error": {"code": ..., "message": ...}}``.
    Non-ASCII text is kept as-is (``ensure_ascii=False``), and the whole frame
    is UTF-8 encoded.

    Args:
        message: Human-readable error description.
        request_id: Correlation ID echoed back to the client.
        code: Error code placed in the payload; defaults to 500.

    Returns:
        The complete frame, including the trailing blank line that terminates
        the SSE event.
    """
    error_detail = {"code": code, "message": message}
    payload = json.dumps(
        {"type": "error", "request_id": request_id, "error": error_detail},
        ensure_ascii=False,
    )
    frame = "event: error\n" + "data: " + payload + "\n\n"
    return frame.encode("utf-8")
|
|
|
|
|
|
@router.post("")
async def chat(payload: ChatRequest):
    """Proxy a chat message to the Tencent SSE endpoint and stream the reply.

    ``settings.bot_app_key`` is checked before any streaming starts. If it is
    missing, the client receives an HTTP 500. Otherwise a
    ``text/event-stream`` response relays the upstream events. Upstream
    failures that occur while streaming (a non-200 status or an
    ``httpx.RequestError``) are reported in-band as an SSE ``error`` event
    built by ``build_error_event``, not as an HTTP error status.

    Args:
        payload: Validated request body.

    Returns:
        A ``StreamingResponse`` relaying the upstream event stream.

    Raises:
        HTTPException: 500 when ``BOT_APP_KEY`` is not configured.
    """
    # First 8 hex chars of a random UUID (same characters as str(uuid)[:8]).
    # The ID prefixes every log line and is also sent upstream.
    request_id = uuid.uuid4().hex[:8]
    logger.info(
        "[%s] Chat request received: session_id=%s, content=%s",
        request_id,
        payload.session_id,
        payload.content[:50],
    )

    if not settings.bot_app_key:
        logger.error("[%s] BOT_APP_KEY is not configured", request_id)
        raise HTTPException(status_code=500, detail="BOT_APP_KEY is not configured")

    upstream_body = dict(
        request_id=request_id,
        content=payload.content,
        bot_app_key=settings.bot_app_key,
        visitor_biz_id=payload.visitor_biz_id,
        session_id=payload.session_id,
        stream="enable",
    )
    sse_headers = {"Accept": "text/event-stream"}

    logger.info("[%s] Sending to Tencent: %s", request_id, TENCENT_SSE_URL)

    async def relay() -> AsyncGenerator[bytes, None]:
        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                async with client.stream(
                    "POST", TENCENT_SSE_URL, json=upstream_body, headers=sse_headers
                ) as upstream:
                    status = upstream.status_code
                    logger.info("[%s] Tencent response status: %s", request_id, status)

                    if status == 200:
                        logger.info("[%s] Starting to forward events", request_id)
                        async for piece in forward_events(upstream, request_id):
                            yield piece
                        logger.info("[%s] Stream completed", request_id)
                    else:
                        # Response headers (200) have already gone to the
                        # client, so report the upstream failure in-band.
                        raw = await upstream.aread()
                        detail = raw.decode("utf-8", errors="replace")
                        logger.error("[%s] Tencent error: %s", request_id, detail)
                        yield build_error_event(detail, request_id, status)
            except httpx.RequestError as err:
                logger.exception("[%s] Request to Tencent failed: %s", request_id, err)
                yield build_error_event(str(err), request_id)

    return StreamingResponse(
        relay(),
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        media_type="text/event-stream",
    )
|