Add bot not interrupt and generated opener

This commit is contained in:
Xin Wang
2026-02-12 13:51:27 +08:00
parent 6179053388
commit d41db6418c
9 changed files with 215 additions and 12 deletions

View File

@@ -2,15 +2,42 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import os
from sqlalchemy import inspect, text
from .db import Base, engine
from .routers import assistants, voices, workflows, history, knowledge, llm, asr, tools
def _ensure_assistant_columns() -> None:
"""Best-effort SQLite schema evolution for assistant flags."""
inspector = inspect(engine)
if "assistants" not in inspector.get_table_names():
return
columns = {col["name"] for col in inspector.get_columns("assistants")}
alter_statements = []
if "generated_opener_enabled" not in columns:
alter_statements.append(
"ALTER TABLE assistants ADD COLUMN generated_opener_enabled BOOLEAN DEFAULT 0"
)
if "bot_cannot_be_interrupted" not in columns:
alter_statements.append(
"ALTER TABLE assistants ADD COLUMN bot_cannot_be_interrupted BOOLEAN DEFAULT 0"
)
if not alter_statements:
return
with engine.begin() as conn:
for stmt in alter_statements:
conn.execute(text(stmt))
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时创建表
Base.metadata.create_all(bind=engine)
_ensure_assistant_columns()
yield

View File

@@ -113,6 +113,7 @@ class Assistant(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False)
call_count: Mapped[int] = mapped_column(Integer, default=0)
opener: Mapped[str] = mapped_column(Text, default="")
generated_opener_enabled: Mapped[bool] = mapped_column(default=False)
prompt: Mapped[str] = mapped_column(Text, default="")
knowledge_base_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
language: Mapped[str] = mapped_column(String(16), default="zh")
@@ -121,6 +122,7 @@ class Assistant(Base):
speed: Mapped[float] = mapped_column(Float, default=1.0)
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
tools: Mapped[dict] = mapped_column(JSON, default=list)
bot_cannot_be_interrupted: Mapped[bool] = mapped_column(default=False)
interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500)
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)

View File

@@ -21,7 +21,12 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
metadata = {
"systemPrompt": assistant.prompt or "",
"greeting": assistant.opener or "",
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
"output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
"bargeIn": {
"enabled": not bool(assistant.bot_cannot_be_interrupted),
"minDurationMs": int(assistant.interruption_sensitivity or 500),
},
"services": {},
}
warnings = []
@@ -100,6 +105,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
"name": assistant.name,
"callCount": assistant.call_count,
"opener": assistant.opener or "",
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
"prompt": assistant.prompt or "",
"knowledgeBaseId": assistant.knowledge_base_id,
"language": assistant.language,
@@ -108,6 +114,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
"speed": assistant.speed,
"hotwords": assistant.hotwords or [],
"tools": assistant.tools or [],
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
"interruptionSensitivity": assistant.interruption_sensitivity,
"configMode": assistant.config_mode,
"apiUrl": assistant.api_url,
@@ -125,8 +132,10 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
field_map = {
"knowledgeBaseId": "knowledge_base_id",
"interruptionSensitivity": "interruption_sensitivity",
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
"configMode": "config_mode",
"voiceOutputEnabled": "voice_output_enabled",
"generatedOpenerEnabled": "generated_opener_enabled",
"apiUrl": "api_url",
"apiKey": "api_key",
"llmModelId": "llm_model_id",
@@ -184,6 +193,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
user_id=1, # 默认用户,后续添加认证
name=data.name,
opener=data.opener,
generated_opener_enabled=data.generatedOpenerEnabled,
prompt=data.prompt,
knowledge_base_id=data.knowledgeBaseId,
language=data.language,
@@ -192,6 +202,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
speed=data.speed,
hotwords=data.hotwords,
tools=data.tools,
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
interruption_sensitivity=data.interruptionSensitivity,
config_mode=data.configMode,
api_url=data.apiUrl,

View File

@@ -273,6 +273,7 @@ class ToolResourceOut(ToolResourceBase):
class AssistantBase(BaseModel):
name: str
opener: str = ""
generatedOpenerEnabled: bool = False
prompt: str = ""
knowledgeBaseId: Optional[str] = None
language: str = "zh"
@@ -281,6 +282,7 @@ class AssistantBase(BaseModel):
speed: float = 1.0
hotwords: List[str] = []
tools: List[str] = []
botCannotBeInterrupted: bool = False
interruptionSensitivity: int = 500
configMode: str = "platform"
apiUrl: Optional[str] = None
@@ -299,6 +301,7 @@ class AssistantCreate(AssistantBase):
class AssistantUpdate(BaseModel):
name: Optional[str] = None
opener: Optional[str] = None
generatedOpenerEnabled: Optional[bool] = None
prompt: Optional[str] = None
knowledgeBaseId: Optional[str] = None
language: Optional[str] = None
@@ -307,6 +310,7 @@ class AssistantUpdate(BaseModel):
speed: Optional[float] = None
hotwords: Optional[List[str]] = None
tools: Optional[List[str]] = None
botCannotBeInterrupted: Optional[bool] = None
interruptionSensitivity: Optional[int] = None
configMode: Optional[str] = None
apiUrl: Optional[str] = None

View File

@@ -24,6 +24,8 @@ class TestAssistantAPI:
assert data["prompt"] == sample_assistant_data["prompt"]
assert data["language"] == sample_assistant_data["language"]
assert data["voiceOutputEnabled"] is True
assert data["generatedOpenerEnabled"] is False
assert data["botCannotBeInterrupted"] is False
assert "id" in data
assert data["callCount"] == 0
@@ -225,3 +227,27 @@ class TestAssistantAPI:
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["output"]["mode"] == "text"
assert metadata["services"]["tts"]["enabled"] is False
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
sample_assistant_data.update({
"generatedOpenerEnabled": True,
"botCannotBeInterrupted": True,
"interruptionSensitivity": 900,
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
get_resp = client.get(f"/api/assistants/{assistant_id}")
assert get_resp.status_code == 200
payload = get_resp.json()
assert payload["generatedOpenerEnabled"] is True
assert payload["botCannotBeInterrupted"] is True
assert payload["interruptionSensitivity"] == 900
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["generatedOpenerEnabled"] is True
assert metadata["bargeIn"]["enabled"] is False
assert metadata["bargeIn"]["minDurationMs"] == 900