Add bot not interrupt and generated opener
This commit is contained in:
@@ -2,15 +2,42 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
from .db import Base, engine
|
||||
from .routers import assistants, voices, workflows, history, knowledge, llm, asr, tools
|
||||
|
||||
|
||||
def _ensure_assistant_columns() -> None:
|
||||
"""Best-effort SQLite schema evolution for assistant flags."""
|
||||
inspector = inspect(engine)
|
||||
if "assistants" not in inspector.get_table_names():
|
||||
return
|
||||
|
||||
columns = {col["name"] for col in inspector.get_columns("assistants")}
|
||||
alter_statements = []
|
||||
if "generated_opener_enabled" not in columns:
|
||||
alter_statements.append(
|
||||
"ALTER TABLE assistants ADD COLUMN generated_opener_enabled BOOLEAN DEFAULT 0"
|
||||
)
|
||||
if "bot_cannot_be_interrupted" not in columns:
|
||||
alter_statements.append(
|
||||
"ALTER TABLE assistants ADD COLUMN bot_cannot_be_interrupted BOOLEAN DEFAULT 0"
|
||||
)
|
||||
|
||||
if not alter_statements:
|
||||
return
|
||||
|
||||
with engine.begin() as conn:
|
||||
for stmt in alter_statements:
|
||||
conn.execute(text(stmt))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 启动时创建表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_ensure_assistant_columns()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -113,6 +113,7 @@ class Assistant(Base):
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
call_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
opener: Mapped[str] = mapped_column(Text, default="")
|
||||
generated_opener_enabled: Mapped[bool] = mapped_column(default=False)
|
||||
prompt: Mapped[str] = mapped_column(Text, default="")
|
||||
knowledge_base_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
language: Mapped[str] = mapped_column(String(16), default="zh")
|
||||
@@ -121,6 +122,7 @@ class Assistant(Base):
|
||||
speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
tools: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
bot_cannot_be_interrupted: Mapped[bool] = mapped_column(default=False)
|
||||
interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500)
|
||||
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
|
||||
api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
@@ -21,7 +21,12 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
metadata = {
|
||||
"systemPrompt": assistant.prompt or "",
|
||||
"greeting": assistant.opener or "",
|
||||
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
|
||||
"output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
|
||||
"bargeIn": {
|
||||
"enabled": not bool(assistant.bot_cannot_be_interrupted),
|
||||
"minDurationMs": int(assistant.interruption_sensitivity or 500),
|
||||
},
|
||||
"services": {},
|
||||
}
|
||||
warnings = []
|
||||
@@ -100,6 +105,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"name": assistant.name,
|
||||
"callCount": assistant.call_count,
|
||||
"opener": assistant.opener or "",
|
||||
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
|
||||
"prompt": assistant.prompt or "",
|
||||
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||
"language": assistant.language,
|
||||
@@ -108,6 +114,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"speed": assistant.speed,
|
||||
"hotwords": assistant.hotwords or [],
|
||||
"tools": assistant.tools or [],
|
||||
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
|
||||
"interruptionSensitivity": assistant.interruption_sensitivity,
|
||||
"configMode": assistant.config_mode,
|
||||
"apiUrl": assistant.api_url,
|
||||
@@ -125,8 +132,10 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
field_map = {
|
||||
"knowledgeBaseId": "knowledge_base_id",
|
||||
"interruptionSensitivity": "interruption_sensitivity",
|
||||
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
|
||||
"configMode": "config_mode",
|
||||
"voiceOutputEnabled": "voice_output_enabled",
|
||||
"generatedOpenerEnabled": "generated_opener_enabled",
|
||||
"apiUrl": "api_url",
|
||||
"apiKey": "api_key",
|
||||
"llmModelId": "llm_model_id",
|
||||
@@ -184,6 +193,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
user_id=1, # 默认用户,后续添加认证
|
||||
name=data.name,
|
||||
opener=data.opener,
|
||||
generated_opener_enabled=data.generatedOpenerEnabled,
|
||||
prompt=data.prompt,
|
||||
knowledge_base_id=data.knowledgeBaseId,
|
||||
language=data.language,
|
||||
@@ -192,6 +202,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
speed=data.speed,
|
||||
hotwords=data.hotwords,
|
||||
tools=data.tools,
|
||||
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
|
||||
interruption_sensitivity=data.interruptionSensitivity,
|
||||
config_mode=data.configMode,
|
||||
api_url=data.apiUrl,
|
||||
|
||||
@@ -273,6 +273,7 @@ class ToolResourceOut(ToolResourceBase):
|
||||
class AssistantBase(BaseModel):
|
||||
name: str
|
||||
opener: str = ""
|
||||
generatedOpenerEnabled: bool = False
|
||||
prompt: str = ""
|
||||
knowledgeBaseId: Optional[str] = None
|
||||
language: str = "zh"
|
||||
@@ -281,6 +282,7 @@ class AssistantBase(BaseModel):
|
||||
speed: float = 1.0
|
||||
hotwords: List[str] = []
|
||||
tools: List[str] = []
|
||||
botCannotBeInterrupted: bool = False
|
||||
interruptionSensitivity: int = 500
|
||||
configMode: str = "platform"
|
||||
apiUrl: Optional[str] = None
|
||||
@@ -299,6 +301,7 @@ class AssistantCreate(AssistantBase):
|
||||
class AssistantUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
opener: Optional[str] = None
|
||||
generatedOpenerEnabled: Optional[bool] = None
|
||||
prompt: Optional[str] = None
|
||||
knowledgeBaseId: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
@@ -307,6 +310,7 @@ class AssistantUpdate(BaseModel):
|
||||
speed: Optional[float] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
tools: Optional[List[str]] = None
|
||||
botCannotBeInterrupted: Optional[bool] = None
|
||||
interruptionSensitivity: Optional[int] = None
|
||||
configMode: Optional[str] = None
|
||||
apiUrl: Optional[str] = None
|
||||
|
||||
@@ -24,6 +24,8 @@ class TestAssistantAPI:
|
||||
assert data["prompt"] == sample_assistant_data["prompt"]
|
||||
assert data["language"] == sample_assistant_data["language"]
|
||||
assert data["voiceOutputEnabled"] is True
|
||||
assert data["generatedOpenerEnabled"] is False
|
||||
assert data["botCannotBeInterrupted"] is False
|
||||
assert "id" in data
|
||||
assert data["callCount"] == 0
|
||||
|
||||
@@ -225,3 +227,27 @@ class TestAssistantAPI:
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["output"]["mode"] == "text"
|
||||
assert metadata["services"]["tts"]["enabled"] is False
|
||||
|
||||
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
|
||||
sample_assistant_data.update({
|
||||
"generatedOpenerEnabled": True,
|
||||
"botCannotBeInterrupted": True,
|
||||
"interruptionSensitivity": 900,
|
||||
})
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert assistant_resp.status_code == 200
|
||||
assistant_id = assistant_resp.json()["id"]
|
||||
|
||||
get_resp = client.get(f"/api/assistants/{assistant_id}")
|
||||
assert get_resp.status_code == 200
|
||||
payload = get_resp.json()
|
||||
assert payload["generatedOpenerEnabled"] is True
|
||||
assert payload["botCannotBeInterrupted"] is True
|
||||
assert payload["interruptionSensitivity"] == 900
|
||||
|
||||
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["generatedOpenerEnabled"] is True
|
||||
assert metadata["bargeIn"]["enabled"] is False
|
||||
assert metadata["bargeIn"]["minDurationMs"] == 900
|
||||
|
||||
Reference in New Issue
Block a user