diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000..8d12426 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1,66 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ +.venv + +# Environment variables +.env +.env.local +.env.*.local + +# Database +*.db +*.sqlite +*.sqlite3 + +# Vector store data +data/vector_store/ +!data/vector_store/.gitkeep + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# Docker +.docker/ + +# Logs +*.log + +# Pytest +.pytest_cache/ +.coverage +htmlcov/ + +# mypy +.mypy_cache/ diff --git a/api/Dockerfile b/api/Dockerfile new file mode 100644 index 0000000..5d4c296 --- /dev/null +++ b/api/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.11-slim + +WORKDIR /app + +# 安装依赖 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制代码 +COPY . . + +# 创建数据目录 +RUN mkdir -p /app/data + +EXPOSE 8000 + +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] diff --git a/api/README.md b/api/README.md index f048780..a83aa9d 100644 --- a/api/README.md +++ b/api/README.md @@ -1 +1,103 @@ -# Backend Service \ No newline at end of file +# AI VideoAssistant Backend + +Python 后端 API,配合前端 `ai-videoassistant-frontend` 使用。 + +## 快速开始 + +### 1. 安装依赖 + +```bash +cd ~/Code/ai-videoassistant-backend +pip install -r requirements.txt +``` + +### 2. 初始化数据库 + +```bash +python init_db.py +``` + +这会: +- 创建 `data/app.db` SQLite 数据库 +- 初始化默认声音数据 + +### 3. 启动服务 + +```bash +# 开发模式 (热重载) +python -m uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +### 4. 测试 API + +```bash +# 健康检查 +curl http://localhost:8000/health + +# 获取助手列表 +curl http://localhost:8000/api/assistants + +# 获取声音列表 +curl http://localhost:8000/api/voices + +# 获取通话历史 +curl http://localhost:8000/api/history +``` + +## API 文档 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/assistants` | GET | 助手列表 | +| `/api/assistants` | POST | 创建助手 | +| `/api/assistants/{id}` | GET | 助手详情 | +| `/api/assistants/{id}` | PUT | 更新助手 | +| `/api/assistants/{id}` | DELETE | 删除助手 | +| `/api/voices` | GET | 声音库列表 | +| `/api/history` | GET | 通话历史列表 | +| `/api/history/{id}` | GET | 通话详情 | +| `/api/history/{id}/transcripts` | POST | 添加转写 | +| `/api/history/{id}/audio/{turn}` | GET | 获取音频 | + +## 使用 Docker 启动 + +```bash +cd ~/Code/ai-videoassistant-backend + +# 启动所有服务 +docker-compose up -d + +# 查看日志 +docker-compose logs -f backend +``` + +## 目录结构 + +``` +backend/ +├── app/ +│ ├── __init__.py +│ ├── main.py # FastAPI 入口 +│ ├── db.py # SQLite 连接 +│ ├── models.py # 数据模型 +│ ├── schemas.py # Pydantic 模型 +│ ├── storage.py # MinIO 存储 +│ └── routers/ +│ ├── __init__.py +│ ├── assistants.py # 助手 API +│ └── history.py # 通话记录 API +├── data/ # 数据库文件 +├── requirements.txt +├── .env +└── docker-compose.yml +``` + +## 环境变量 + +| 变量 | 默认值 | 说明 | +|------|--------|------| +| `DATABASE_URL` | `sqlite:///./data/app.db` | 数据库连接 | +| `MINIO_ENDPOINT` | `localhost:9000` | MinIO 地址 | +| `MINIO_ACCESS_KEY` | `admin` | MinIO 密钥 | +| `MINIO_SECRET_KEY` | `password123` | MinIO 密码 | +| `MINIO_BUCKET` | `ai-audio` | 存储桶名称 | diff --git a/api/app/db.py b/api/app/db.py new file mode 100644 index 0000000..c3086c1 --- /dev/null +++ b/api/app/db.py @@ -0,0 +1,19 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, DeclarativeBase + +DATABASE_URL = "sqlite:///./data/app.db" + 
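The README above documents a `DATABASE_URL` environment variable (default `sqlite:///./data/app.db`), but `db.py` hard-codes the path. A minimal sketch of honoring the variable while keeping the same default, assuming nothing beyond the standard library:

```python
import os

# Use DATABASE_URL from the environment when set, otherwise fall back
# to the bundled SQLite file (matches the README's documented default).
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./data/app.db")
```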
+engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +class Base(DeclarativeBase): + pass + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/api/app/main.py b/api/app/main.py new file mode 100644 index 0000000..751c5dc --- /dev/null +++ b/api/app/main.py @@ -0,0 +1,72 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +import os + +from .db import Base, engine +from .routers import assistants, history + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动时创建表 + Base.metadata.create_all(bind=engine) + yield + + +app = FastAPI( + title="AI VideoAssistant API", + description="Backend API for AI VideoAssistant", + version="1.0.0", + lifespan=lifespan +) + +# CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 路由 +app.include_router(assistants.router, prefix="/api") +app.include_router(history.router, prefix="/api") + + +@app.get("/") +def root(): + return {"message": "AI VideoAssistant API", "version": "1.0.0"} + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +# 初始化默认数据 +@app.on_event("startup") +def init_default_data(): + from sqlalchemy.orm import Session + from .db import SessionLocal + from .models import Voice + + db = SessionLocal() + try: + # 检查是否已有数据 + if db.query(Voice).count() == 0: + # 插入默认声音 + voices = [ + Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."), + Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."), + Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."), + Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."), + Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."), + ] + for v in voices: + db.add(v) + db.commit() + print("✅ 默认声音数据已初始化") + finally: + db.close() diff --git a/api/app/models.py b/api/app/models.py new file mode 100644 index 0000000..0d80eeb --- /dev/null +++ b/api/app/models.py @@ -0,0 +1,165 @@ +from datetime import datetime +from typing import List, Optional +from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .db import Base + + +class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + email: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + +class Voice(Base): + __tablename__ = "voices" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + name: Mapped[str] = mapped_column(String(128), nullable=False) + vendor: Mapped[str] = mapped_column(String(64), nullable=False) + gender: Mapped[str] = mapped_column(String(32), nullable=False) + language: Mapped[str] = mapped_column(String(16), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) + voice_params: Mapped[dict] = mapped_column(JSON, default=dict) + + 
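`app/main.py` above seeds the default `Voice` rows in an `@app.on_event("startup")` hook even though it already defines a `lifespan` context manager, and FastAPI's documentation now favors `lifespan` over `on_event`. A hedged sketch of folding the seeding into that handler, reusing `SessionLocal`, `Base`, and the `Voice` model shown here (only one voice kept for brevity):

```python
from contextlib import asynccontextmanager
from fastapi import FastAPI

from .db import Base, SessionLocal, engine
from .models import Voice


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create tables and seed default voices once, at startup.
    Base.metadata.create_all(bind=engine)
    db = SessionLocal()
    try:
        if db.query(Voice).count() == 0:
            db.add(Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female",
                         language="zh", description="Gentle and professional."))
            db.commit()
    finally:
        db.close()
    yield
```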
+class Assistant(Base): + __tablename__ = "assistants" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + call_count: Mapped[int] = mapped_column(Integer, default=0) + opener: Mapped[str] = mapped_column(Text, default="") + prompt: Mapped[str] = mapped_column(Text, default="") + knowledge_base_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + language: Mapped[str] = mapped_column(String(16), default="zh") + voice: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + speed: Mapped[float] = mapped_column(Float, default=1.0) + hotwords: Mapped[dict] = mapped_column(JSON, default=list) + tools: Mapped[dict] = mapped_column(JSON, default=list) + interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500) + config_mode: Mapped[str] = mapped_column(String(32), default="platform") + api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + api_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + user = relationship("User") + call_records = relationship("CallRecord", back_populates="assistant") + + +class KnowledgeBase(Base): + __tablename__ = "knowledge_bases" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(Text, default="") + embedding_model: Mapped[str] = mapped_column(String(64), default="text-embedding-3-small") + chunk_size: Mapped[int] = mapped_column(Integer, default=500) + chunk_overlap: Mapped[int] = mapped_column(Integer, default=50) + doc_count: Mapped[int] = mapped_column(Integer, default=0) + chunk_count: Mapped[int] = mapped_column(Integer, default=0) + status: Mapped[str] = mapped_column(String(32), default="active") + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + user = relationship("User") + documents = relationship("KnowledgeDocument", back_populates="kb") + + +class KnowledgeDocument(Base): + __tablename__ = "knowledge_documents" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + kb_id: Mapped[str] = mapped_column(String(64), ForeignKey("knowledge_bases.id"), index=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + size: Mapped[str] = mapped_column(String(64), nullable=False) + file_type: Mapped[str] = mapped_column(String(32), default="txt") + storage_url: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) + status: Mapped[str] = mapped_column(String(32), default="pending") # pending/processing/completed/failed + chunk_count: Mapped[int] = mapped_column(Integer, default=0) + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + upload_date: Mapped[str] = mapped_column(String(32), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + kb = relationship("KnowledgeBase", back_populates="documents") + + +class Workflow(Base): + __tablename__ 
= "workflows" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + node_count: Mapped[int] = mapped_column(Integer, default=0) + created_at: Mapped[str] = mapped_column(String(32), default="") + updated_at: Mapped[str] = mapped_column(String(32), default="") + global_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + nodes: Mapped[dict] = mapped_column(JSON, default=list) + edges: Mapped[dict] = mapped_column(JSON, default=list) + + user = relationship("User") + + +class CallRecord(Base): + __tablename__ = "call_records" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + assistant_id: Mapped[Optional[str]] = mapped_column(String(64), ForeignKey("assistants.id"), index=True) + source: Mapped[str] = mapped_column(String(32), default="debug") + status: Mapped[str] = mapped_column(String(32), default="connected") + started_at: Mapped[str] = mapped_column(String(32), nullable=False) + ended_at: Mapped[Optional[str]] = mapped_column(String(32), nullable=True) + duration_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + cost: Mapped[float] = mapped_column(Float, default=0.0) + call_metadata: Mapped[dict] = mapped_column(JSON, default=dict) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + user = relationship("User") + assistant = relationship("Assistant", back_populates="call_records") + transcripts = relationship("CallTranscript", back_populates="call_record") + audio_segments = relationship("CallAudioSegment", back_populates="call_record") + + +class CallTranscript(Base): + __tablename__ = "call_transcripts" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + call_id: Mapped[str] = mapped_column(String(64), ForeignKey("call_records.id"), index=True) + turn_index: Mapped[int] = mapped_column(Integer, nullable=False) + speaker: Mapped[str] = mapped_column(String(16), nullable=False) # human/ai + content: Mapped[str] = mapped_column(Text, nullable=False) + confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + start_ms: Mapped[int] = mapped_column(Integer, nullable=False) + end_ms: Mapped[int] = mapped_column(Integer, nullable=False) + duration_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + emotion: Mapped[Optional[str]] = mapped_column(String(32), nullable=True) + + call_record = relationship("CallRecord", back_populates="transcripts") + + +class CallAudioSegment(Base): + __tablename__ = "call_audio_segments" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + call_id: Mapped[str] = mapped_column(String(64), ForeignKey("call_records.id"), index=True) + transcript_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("call_transcripts.id"), nullable=True) + turn_index: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + audio_url: Mapped[str] = mapped_column(String(512), nullable=False) + audio_format: Mapped[str] = mapped_column(String(16), default="mp3") + file_size_bytes: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + start_ms: Mapped[int] = mapped_column(Integer, nullable=False) + end_ms: Mapped[int] = mapped_column(Integer, nullable=False) + duration_ms: 
Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + call_record = relationship("CallRecord", back_populates="audio_segments") diff --git a/api/app/routers/__init__.py b/api/app/routers/__init__.py new file mode 100644 index 0000000..5b92416 --- /dev/null +++ b/api/app/routers/__init__.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +from . import assistants +from . import history +from . import knowledge + +router = APIRouter() + +router.include_router(assistants.router) +router.include_router(history.router) +router.include_router(knowledge.router) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py new file mode 100644 index 0000000..b67fc62 --- /dev/null +++ b/api/app/routers/assistants.py @@ -0,0 +1,157 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import List +import uuid +from datetime import datetime + +from ..db import get_db +from ..models import Assistant, Voice, Workflow +from ..schemas import ( + AssistantCreate, AssistantUpdate, AssistantOut, + VoiceOut, + WorkflowCreate, WorkflowUpdate, WorkflowOut +) + +router = APIRouter() + + +# ============ Voices ============ +@router.get("/voices", response_model=List[VoiceOut]) +def list_voices(db: Session = Depends(get_db)): + """获取声音库列表""" + voices = db.query(Voice).all() + return voices + + +# ============ Assistants ============ +@router.get("/assistants") +def list_assistants( + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): + """获取助手列表""" + query = db.query(Assistant) + total = query.count() + assistants = query.order_by(Assistant.created_at.desc()) \ + .offset((page-1)*limit).limit(limit).all() + return {"total": total, "page": page, "limit": limit, "list": assistants} + + +@router.get("/assistants/{id}", response_model=AssistantOut) +def get_assistant(id: str, db: Session = Depends(get_db)): + """获取单个助手详情""" + assistant = db.query(Assistant).filter(Assistant.id == id).first() + if not assistant: + raise HTTPException(status_code=404, detail="Assistant not found") + return assistant + + +@router.post("/assistants", response_model=AssistantOut) +def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): + """创建新助手""" + assistant = Assistant( + id=str(uuid.uuid4())[:8], + user_id=1, # 默认用户,后续添加认证 + name=data.name, + opener=data.opener, + prompt=data.prompt, + knowledge_base_id=data.knowledgeBaseId, + language=data.language, + voice=data.voice, + speed=data.speed, + hotwords=data.hotwords, + tools=data.tools, + interruption_sensitivity=data.interruptionSensitivity, + config_mode=data.configMode, + api_url=data.apiUrl, + api_key=data.apiKey, + ) + db.add(assistant) + db.commit() + db.refresh(assistant) + return assistant + + +@router.put("/assistants/{id}") +def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)): + """更新助手""" + assistant = db.query(Assistant).filter(Assistant.id == id).first() + if not assistant: + raise HTTPException(status_code=404, detail="Assistant not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(assistant, field, value) + + assistant.updated_at = datetime.utcnow() + db.commit() + db.refresh(assistant) + return assistant + + +@router.delete("/assistants/{id}") +def delete_assistant(id: str, db: Session = Depends(get_db)): + """删除助手""" + assistant = db.query(Assistant).filter(Assistant.id == 
id).first() + if not assistant: + raise HTTPException(status_code=404, detail="Assistant not found") + db.delete(assistant) + db.commit() + return {"message": "Deleted successfully"} + + +# ============ Workflows ============ +@router.get("/workflows", response_model=List[WorkflowOut]) +def list_workflows(db: Session = Depends(get_db)): + """获取工作流列表""" + workflows = db.query(Workflow).all() + return workflows + + +@router.post("/workflows", response_model=WorkflowOut) +def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)): + """创建工作流""" + workflow = Workflow( + id=str(uuid.uuid4())[:8], + user_id=1, + name=data.name, + node_count=data.nodeCount, + created_at=data.createdAt or datetime.utcnow().isoformat(), + updated_at=data.updatedAt or "", + global_prompt=data.globalPrompt, + nodes=data.nodes, + edges=data.edges, + ) + db.add(workflow) + db.commit() + db.refresh(workflow) + return workflow + + +@router.put("/workflows/{id}", response_model=WorkflowOut) +def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)): + """更新工作流""" + workflow = db.query(Workflow).filter(Workflow.id == id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(workflow, field, value) + + workflow.updated_at = datetime.utcnow().isoformat() + db.commit() + db.refresh(workflow) + return workflow + + +@router.delete("/workflows/{id}") +def delete_workflow(id: str, db: Session = Depends(get_db)): + """删除工作流""" + workflow = db.query(Workflow).filter(Workflow.id == id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + db.delete(workflow) + db.commit() + return {"message": "Deleted successfully"} diff --git a/api/app/routers/history.py b/api/app/routers/history.py new file mode 100644 index 0000000..9434541 --- /dev/null +++ b/api/app/routers/history.py @@ -0,0 +1,188 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import Optional, List +import uuid +from datetime import datetime + +from ..db import get_db +from ..models import CallRecord, CallTranscript, CallAudioSegment +from ..storage import get_audio_url + +router = APIRouter(prefix="/history", tags=["history"]) + + +@router.get("") +def list_history( + assistant_id: Optional[str] = None, + status: Optional[str] = None, + page: int = 1, + limit: int = 20, + db: Session = Depends(get_db) +): + """获取通话记录列表""" + query = db.query(CallRecord) + + if assistant_id: + query = query.filter(CallRecord.assistant_id == assistant_id) + if status: + query = query.filter(CallRecord.status == status) + + total = query.count() + records = query.order_by(CallRecord.started_at.desc()) \ + .offset((page-1)*limit).limit(limit).all() + + return {"total": total, "page": page, "limit": limit, "list": records} + + +@router.get("/{call_id}") +def get_history_detail(call_id: str, db: Session = Depends(get_db)): + """获取通话详情""" + record = db.query(CallRecord).filter(CallRecord.id == call_id).first() + if not record: + raise HTTPException(status_code=404, detail="Call record not found") + + # 获取转写 + transcripts = db.query(CallTranscript) \ + .filter(CallTranscript.call_id == call_id) \ + .order_by(CallTranscript.turn_index).all() + + # 补充音频 URL + transcript_list = [] + for t in transcripts: + audio_url = t.audio_url or get_audio_url(call_id, t.turn_index) + transcript_list.append({ + "turnIndex": 
t.turn_index, + "speaker": t.speaker, + "content": t.content, + "confidence": t.confidence, + "startMs": t.start_ms, + "endMs": t.end_ms, + "durationMs": t.duration_ms, + "audioUrl": audio_url, + }) + + return { + "id": record.id, + "user_id": record.user_id, + "assistant_id": record.assistant_id, + "source": record.source, + "status": record.status, + "started_at": record.started_at, + "ended_at": record.ended_at, + "duration_seconds": record.duration_seconds, + "summary": record.summary, + "transcripts": transcript_list, + } + + +@router.post("") +def create_call_record( + user_id: int, + assistant_id: Optional[str] = None, + source: str = "debug", + db: Session = Depends(get_db) +): + """创建通话记录(引擎回调使用)""" + record = CallRecord( + id=str(uuid.uuid4())[:8], + user_id=user_id, + assistant_id=assistant_id, + source=source, + status="connected", + started_at=datetime.utcnow().isoformat(), + ) + db.add(record) + db.commit() + db.refresh(record) + return record + + +@router.put("/{call_id}") +def update_call_record( + call_id: str, + status: Optional[str] = None, + summary: Optional[str] = None, + duration_seconds: Optional[int] = None, + db: Session = Depends(get_db) +): + """更新通话记录""" + record = db.query(CallRecord).filter(CallRecord.id == call_id).first() + if not record: + raise HTTPException(status_code=404, detail="Call record not found") + + if status: + record.status = status + if summary: + record.summary = summary + if duration_seconds: + record.duration_seconds = duration_seconds + record.ended_at = datetime.utcnow().isoformat() + + db.commit() + return {"message": "Updated successfully"} + + +@router.post("/{call_id}/transcripts") +def add_transcript( + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + emotion: Optional[str] = None, + db: Session = Depends(get_db) +): + """添加转写片段""" + transcript = CallTranscript( + call_id=call_id, + turn_index=turn_index, + speaker=speaker, + content=content, + confidence=confidence, + start_ms=start_ms, + end_ms=end_ms, + duration_ms=duration_ms, + emotion=emotion, + ) + db.add(transcript) + db.commit() + db.refresh(transcript) + + # 补充音频 URL + audio_url = get_audio_url(call_id, turn_index) + + return { + "id": transcript.id, + "turn_index": turn_index, + "speaker": speaker, + "content": content, + "confidence": confidence, + "start_ms": start_ms, + "end_ms": end_ms, + "duration_ms": duration_ms, + "audio_url": audio_url, + } + + +@router.get("/{call_id}/audio/{turn_index}") +def get_audio(call_id: str, turn_index: int): + """获取音频文件""" + audio_url = get_audio_url(call_id, turn_index) + if not audio_url: + raise HTTPException(status_code=404, detail="Audio not found") + from fastapi.responses import RedirectResponse + return RedirectResponse(audio_url) + + +@router.delete("/{call_id}") +def delete_call_record(call_id: str, db: Session = Depends(get_db)): + """删除通话记录""" + record = db.query(CallRecord).filter(CallRecord.id == call_id).first() + if not record: + raise HTTPException(status_code=404, detail="Call record not found") + db.delete(record) + db.commit() + return {"message": "Deleted successfully"} diff --git a/api/app/routers/knowledge.py b/api/app/routers/knowledge.py new file mode 100644 index 0000000..2d778fe --- /dev/null +++ b/api/app/routers/knowledge.py @@ -0,0 +1,234 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session +from typing import Optional +import uuid 
+import os +from datetime import datetime + +from ..db import get_db +from ..models import KnowledgeBase, KnowledgeDocument +from ..schemas import ( + KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut, + KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats, + DocumentIndexRequest, +) +from ..vector_store import ( + vector_store, search_knowledge, index_document, delete_document_from_vector +) + +router = APIRouter(prefix="/knowledge", tags=["knowledge"]) + + +def kb_to_dict(kb: KnowledgeBase) -> dict: + return { + "id": kb.id, + "user_id": kb.user_id, + "name": kb.name, + "description": kb.description, + "embedding_model": kb.embedding_model, + "chunk_size": kb.chunk_size, + "chunk_overlap": kb.chunk_overlap, + "doc_count": kb.doc_count, + "chunk_count": kb.chunk_count, + "status": kb.status, + "created_at": kb.created_at.isoformat() if kb.created_at else None, + "updated_at": kb.updated_at.isoformat() if kb.updated_at else None, + } + + +def doc_to_dict(d: KnowledgeDocument) -> dict: + return { + "id": d.id, + "kb_id": d.kb_id, + "name": d.name, + "size": d.size, + "file_type": d.file_type, + "storage_url": d.storage_url, + "status": d.status, + "chunk_count": d.chunk_count, + "error_message": d.error_message, + "upload_date": d.upload_date, + "created_at": d.created_at.isoformat() if d.created_at else None, + "processed_at": d.processed_at.isoformat() if d.processed_at else None, + } + + +# ============ Knowledge Bases ============ +@router.get("/bases") +def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)): + kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all() + result = [] + for kb in kbs: + docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all() + kb_data = kb_to_dict(kb) + kb_data["documents"] = [doc_to_dict(d) for d in docs] + result.append(kb_data) + return {"total": len(result), "list": result} + + +@router.get("/bases/{kb_id}") +def get_knowledge_base(kb_id: str, db: Session = Depends(get_db)): + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not kb: + raise HTTPException(status_code=404, detail="Knowledge base not found") + docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all() + kb_data = kb_to_dict(kb) + kb_data["documents"] = [doc_to_dict(d) for d in docs] + return kb_data + + +@router.post("/bases") +def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Session = Depends(get_db)): + kb = KnowledgeBase( + id=str(uuid.uuid4())[:8], + user_id=user_id, + name=data.name, + description=data.description, + embedding_model=data.embeddingModel, + chunk_size=data.chunkSize, + chunk_overlap=data.chunkOverlap, + ) + db.add(kb) + db.commit() + db.refresh(kb) + vector_store.create_collection(kb.id, data.embeddingModel) + return kb_to_dict(kb) + + +@router.put("/bases/{kb_id}") +def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = Depends(get_db)): + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not kb: + raise HTTPException(status_code=404, detail="Knowledge base not found") + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(kb, field, value) + kb.updated_at = datetime.utcnow() + db.commit() + db.refresh(kb) + return kb_to_dict(kb) + + +@router.delete("/bases/{kb_id}") +def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)): + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not 
kb: + raise HTTPException(status_code=404, detail="Knowledge base not found") + vector_store.delete_collection(kb_id) + docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all() + for doc in docs: + db.delete(doc) + db.delete(kb) + db.commit() + return {"message": "Deleted successfully"} + + +# ============ Documents ============ +@router.post("/bases/{kb_id}/documents") +def upload_document( + kb_id: str, + name: str = Query(...), + size: str = Query(...), + file_type: str = Query("txt"), + storage_url: Optional[str] = Query(None), + db: Session = Depends(get_db) +): + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not kb: + raise HTTPException(status_code=404, detail="Knowledge base not found") + doc = KnowledgeDocument( + id=str(uuid.uuid4())[:8], + kb_id=kb_id, + name=name, + size=size, + file_type=file_type, + storage_url=storage_url, + status="pending", + upload_date=datetime.utcnow().isoformat() + ) + db.add(doc) + db.commit() + db.refresh(doc) + return {"id": doc.id, "name": doc.name, "status": doc.status, "message": "Document created"} + + +@router.post("/bases/{kb_id}/documents/{doc_id}/index") +def index_document_content(kb_id: str, doc_id: str, request: DocumentIndexRequest, db: Session = Depends(get_db)): + # 检查文档是否存在,不存在则创建 + doc = db.query(KnowledgeDocument).filter( + KnowledgeDocument.id == doc_id, + KnowledgeDocument.kb_id == kb_id + ).first() + + if not doc: + doc = KnowledgeDocument( + id=doc_id, + kb_id=kb_id, + name=f"doc-{doc_id}.txt", + size=str(len(request.content)), + file_type="txt", + status="pending", + upload_date=datetime.utcnow().isoformat() + ) + db.add(doc) + db.commit() + db.refresh(doc) + else: + # 更新已有文档 + doc.size = str(len(request.content)) + doc.status = "pending" + db.commit() + + try: + chunk_count = index_document(kb_id, doc_id, request.content) + doc.status = "completed" + doc.chunk_count = chunk_count + doc.processed_at = datetime.utcnow() + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + kb.doc_count = db.query(KnowledgeDocument).filter( + KnowledgeDocument.kb_id == kb_id, + KnowledgeDocument.status == "completed" + ).count() + kb.chunk_count += chunk_count + db.commit() + return {"message": "Document indexed", "chunkCount": chunk_count} + except Exception as e: + doc.status = "failed" + doc.error_message = str(e) + db.commit() + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/bases/{kb_id}/documents/{doc_id}") +def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)): + doc = db.query(KnowledgeDocument).filter( + KnowledgeDocument.id == doc_id, + KnowledgeDocument.kb_id == kb_id + ).first() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + try: + delete_document_from_vector(kb_id, doc_id) + except Exception: + pass + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + kb.chunk_count -= doc.chunk_count + kb.doc_count -= 1 + db.delete(doc) + db.commit() + return {"message": "Deleted successfully"} + + +# ============ Search ============ +@router.post("/search") +def search_knowledge_base(query: KnowledgeSearchQuery): + return search_knowledge(kb_id=query.kb_id, query=query.query, n_results=query.nResults) + + +# ============ Stats ============ +@router.get("/bases/{kb_id}/stats") +def get_knowledge_stats(kb_id: str, db: Session = Depends(get_db)): + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not kb: + raise HTTPException(status_code=404, 
detail="Knowledge base not found") + return {"kb_id": kb_id, "docCount": kb.doc_count, "chunkCount": kb.chunk_count} diff --git a/api/app/schemas.py b/api/app/schemas.py new file mode 100644 index 0000000..afba1d0 --- /dev/null +++ b/api/app/schemas.py @@ -0,0 +1,271 @@ +from datetime import datetime +from typing import List, Optional +from pydantic import BaseModel + + +# ============ Voice ============ +class VoiceBase(BaseModel): + name: str + vendor: str + gender: str + language: str + description: str + + +class VoiceOut(VoiceBase): + id: str + + class Config: + from_attributes = True + + +# ============ Assistant ============ +class AssistantBase(BaseModel): + name: str + opener: str = "" + prompt: str = "" + knowledgeBaseId: Optional[str] = None + language: str = "zh" + voice: Optional[str] = None + speed: float = 1.0 + hotwords: List[str] = [] + tools: List[str] = [] + interruptionSensitivity: int = 500 + configMode: str = "platform" + apiUrl: Optional[str] = None + apiKey: Optional[str] = None + + +class AssistantCreate(AssistantBase): + pass + + +class AssistantUpdate(AssistantBase): + name: Optional[str] = None + + +class AssistantOut(AssistantBase): + id: str + callCount: int = 0 + created_at: Optional[datetime] = None + + class Config: + from_attributes = True + + +# ============ Knowledge Base ============ +class KnowledgeDocument(BaseModel): + id: str + name: str + size: str + fileType: str = "txt" + storageUrl: Optional[str] = None + status: str = "pending" + chunkCount: int = 0 + uploadDate: str + + +class KnowledgeDocumentCreate(BaseModel): + name: str + size: str + fileType: str = "txt" + storageUrl: Optional[str] = None + + +class KnowledgeDocumentUpdate(BaseModel): + status: Optional[str] = None + chunkCount: Optional[int] = None + errorMessage: Optional[str] = None + + +class KnowledgeBaseBase(BaseModel): + name: str + description: str = "" + embeddingModel: str = "text-embedding-3-small" + chunkSize: int = 500 + chunkOverlap: int = 50 + + +class KnowledgeBaseCreate(KnowledgeBaseBase): + pass + + +class KnowledgeBaseUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + embeddingModel: Optional[str] = None + chunkSize: Optional[int] = None + chunkOverlap: Optional[int] = None + status: Optional[str] = None + + +class KnowledgeBaseOut(KnowledgeBaseBase): + id: str + docCount: int = 0 + chunkCount: int = 0 + status: str = "active" + createdAt: Optional[datetime] = None + updatedAt: Optional[datetime] = None + documents: List[KnowledgeDocument] = [] + + class Config: + from_attributes = True + + +# ============ Knowledge Search ============ +class KnowledgeSearchQuery(BaseModel): + query: str + kb_id: str + nResults: int = 5 + + +class KnowledgeSearchResult(BaseModel): + query: str + results: List[dict] + + +class DocumentIndexRequest(BaseModel): + document_id: str + content: str + + +class KnowledgeStats(BaseModel): + kb_id: str + docCount: int + chunkCount: int + + +# ============ Workflow ============ +class WorkflowNode(BaseModel): + name: str + type: str + isStart: Optional[bool] = None + metadata: dict + prompt: Optional[str] = None + messagePlan: Optional[dict] = None + variableExtractionPlan: Optional[dict] = None + tool: Optional[dict] = None + globalNodePlan: Optional[dict] = None + + +class WorkflowEdge(BaseModel): + from_: str + to: str + label: Optional[str] = None + + class Config: + populate_by_name = True + + +class WorkflowBase(BaseModel): + name: str + nodeCount: int = 0 + createdAt: str = "" + updatedAt: str = "" + 
globalPrompt: Optional[str] = None + nodes: List[dict] = [] + edges: List[dict] = [] + + +class WorkflowCreate(WorkflowBase): + pass + + +class WorkflowUpdate(BaseModel): + name: Optional[str] = None + nodeCount: Optional[int] = None + nodes: Optional[List[dict]] = None + edges: Optional[List[dict]] = None + globalPrompt: Optional[str] = None + + +class WorkflowOut(WorkflowBase): + id: str + + class Config: + from_attributes = True + + +# ============ Call Record ============ +class TranscriptSegment(BaseModel): + turnIndex: int + speaker: str # human/ai + content: str + confidence: Optional[float] = None + startMs: int + endMs: int + durationMs: Optional[int] = None + audioUrl: Optional[str] = None + + +class CallRecordCreate(BaseModel): + user_id: int + assistant_id: Optional[str] = None + source: str = "debug" + + +class CallRecordUpdate(BaseModel): + status: Optional[str] = None + summary: Optional[str] = None + duration_seconds: Optional[int] = None + + +class CallRecordOut(BaseModel): + id: str + user_id: int + assistant_id: Optional[str] = None + source: str + status: str + started_at: str + ended_at: Optional[str] = None + duration_seconds: Optional[int] = None + summary: Optional[str] = None + transcripts: List[TranscriptSegment] = [] + + class Config: + from_attributes = True + + +# ============ Call Transcript ============ +class TranscriptCreate(BaseModel): + turn_index: int + speaker: str + content: str + confidence: Optional[float] = None + start_ms: int + end_ms: int + duration_ms: Optional[int] = None + emotion: Optional[str] = None + + +class TranscriptOut(TranscriptCreate): + id: int + audio_url: Optional[str] = None + + class Config: + from_attributes = True + + +# ============ Dashboard ============ +class DashboardStats(BaseModel): + totalCalls: int + answerRate: int + avgDuration: str + humanTransferCount: int + trend: List[dict] + + +# ============ API Response ============ +class Message(BaseModel): + message: str + + +class DocumentIndexRequest(BaseModel): + content: str + + +class ListResponse(BaseModel): + total: int + page: int + limit: int + list: List diff --git a/api/app/storage.py b/api/app/storage.py new file mode 100644 index 0000000..d316e68 --- /dev/null +++ b/api/app/storage.py @@ -0,0 +1,56 @@ +import os +from datetime import datetime +from minio import Minio +import uuid + +# MinIO 配置 +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000") +MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "admin") +MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "password123") +MINIO_BUCKET = os.getenv("MINIO_BUCKET", "ai-audio") + +# 初始化客户端 +minio_client = Minio( + MINIO_ENDPOINT, + access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, + secure=False +) + + +def ensure_bucket(): + """确保 Bucket 存在""" + try: + if not minio_client.bucket_exists(MINIO_BUCKET): + minio_client.make_bucket(MINIO_BUCKET) + except Exception as e: + print(f"Warning: MinIO bucket check failed: {e}") + + +def upload_audio(file_path: str, call_id: str, turn_index: int) -> str: + """上传音频片段,返回访问 URL""" + ensure_bucket() + + ext = os.path.splitext(file_path)[1] or ".mp3" + object_name = f"{call_id}/{call_id}-{turn_index:03d}{ext}" + + try: + minio_client.fput_object(MINIO_BUCKET, object_name, file_path) + return minio_client.presigned_get_object(MINIO_BUCKET, object_name, expires=604800) + except Exception as e: + print(f"Warning: MinIO upload failed: {e}") + return "" + + +def get_audio_url(call_id: str, turn_index: int) -> str: + """获取音频 URL""" + object_name = 
f"{call_id}/{call_id}-{turn_index:03d}.mp3" + try: + return minio_client.presigned_get_object(MINIO_BUCKET, object_name, expires=604800) + except Exception: + return "" + + +def generate_local_url(call_id: str, turn_index: int) -> str: + """生成本地 URL(如果不用 MinIO)""" + return f"/api/history/{call_id}/audio/{turn_index}" diff --git a/api/app/vector_store.py b/api/app/vector_store.py new file mode 100644 index 0000000..418484a --- /dev/null +++ b/api/app/vector_store.py @@ -0,0 +1,311 @@ +""" +向量数据库服务 (ChromaDB) +""" +import os +from typing import List, Dict, Optional +import chromadb +from chromadb.config import Settings + +# 配置 +VECTOR_STORE_PATH = os.getenv("VECTOR_STORE_PATH", "./data/vector_store") +COLLECTION_NAME_PREFIX = "kb_" + + +class VectorStore: + """向量存储服务""" + + def __init__(self): + os.makedirs(VECTOR_STORE_PATH, exist_ok=True) + self.client = chromadb.PersistentClient( + path=VECTOR_STORE_PATH, + settings=Settings(anonymized_telemetry=False) + ) + + def get_collection(self, kb_id: str): + """获取知识库集合""" + collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}" + try: + return self.client.get_collection(name=collection_name) + except (ValueError, chromadb.errors.NotFoundError): + return None + + def create_collection(self, kb_id: str, embedding_model: str = "text-embedding-3-small"): + """创建知识库向量集合""" + collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}" + try: + self.client.get_collection(name=collection_name) + return collection_name + except (ValueError, chromadb.errors.NotFoundError): + self.client.create_collection( + name=collection_name, + metadata={ + "kb_id": kb_id, + "embedding_model": embedding_model + } + ) + return collection_name + + def delete_collection(self, kb_id: str): + """删除知识库向量集合""" + collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}" + try: + self.client.delete_collection(name=collection_name) + return True + except (ValueError, chromadb.errors.NotFoundError): + return False + + def add_documents( + self, + kb_id: str, + documents: List[str], + embeddings: Optional[List[List[float]]] = None, + ids: Optional[List[str]] = None, + metadatas: Optional[List[Dict]] = None + ): + """添加文档片段到向量库""" + collection = self.get_collection(kb_id) + + if ids is None: + ids = [f"chunk-{i}" for i in range(len(documents))] + + if embeddings is not None: + collection.add( + documents=documents, + embeddings=embeddings, + ids=ids, + metadatas=metadatas + ) + else: + collection.add( + documents=documents, + ids=ids, + metadatas=metadatas + ) + + return len(documents) + + def search( + self, + kb_id: str, + query: str, + n_results: int = 5, + where: Optional[Dict] = None + ) -> Dict: + """检索相似文档""" + collection = self.get_collection(kb_id) + + # 生成查询向量 + query_embedding = embedding_service.embed_query(query) + + results = collection.query( + query_embeddings=[query_embedding], + n_results=n_results, + where=where + ) + + return results + + def get_stats(self, kb_id: str) -> Dict: + """获取向量库统计""" + collection = self.get_collection(kb_id) + return { + "count": collection.count(), + "kb_id": kb_id + } + + def delete_documents(self, kb_id: str, ids: List[str]): + """删除指定文档片段""" + collection = self.get_collection(kb_id) + collection.delete(ids=ids) + + def delete_by_metadata(self, kb_id: str, document_id: str): + """根据文档 ID 删除所有片段""" + collection = self.get_collection(kb_id) + results = collection.get(where={"document_id": document_id}) + if results["ids"]: + collection.delete(ids=results["ids"]) + + +class EmbeddingService: + """ embedding 服务(支持多种模型)""" + + def __init__(self, model: 
str = "text-embedding-3-small"): + self.model = model + self._client = None + + def _get_client(self): + """获取 OpenAI 客户端""" + if self._client is None: + try: + from openai import OpenAI + api_key = os.getenv("OPENAI_API_KEY") + if api_key: + self._client = OpenAI(api_key=api_key) + except ImportError: + pass + return self._client + + def embed(self, texts: List[str]) -> List[List[float]]: + """生成 embedding 向量""" + client = self._get_client() + + if client is None: + # 返回随机向量(仅用于测试) + import random + import math + dim = 1536 if "3-small" in self.model else 1024 + return [[random.uniform(-1, 1) for _ in range(dim)] for _ in texts] + + response = client.embeddings.create( + model=self.model, + input=texts + ) + return [data.embedding for data in response.data] + + def embed_query(self, query: str) -> List[float]: + """生成查询向量""" + return self.embed([query])[0] + + +class DocumentProcessor: + """文档处理服务""" + + def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def chunk_text(self, text: str, document_id: str = "") -> List[Dict]: + """将文本分块""" + # 简单分块(按句子/段落) + import re + + # 按句子分割 + sentences = re.split(r'[。!?\n]', text) + + chunks = [] + current_chunk = "" + current_size = 0 + + for i, sentence in enumerate(sentences): + sentence = sentence.strip() + if not sentence: + continue + + sentence_len = len(sentence) + + if current_size + sentence_len > self.chunk_size and current_chunk: + # 保存当前块 + chunks.append({ + "content": current_chunk.strip(), + "document_id": document_id, + "chunk_index": len(chunks), + "metadata": { + "source": "text" + } + }) + + # 处理重叠 + if self.chunk_overlap > 0: + # 保留末尾部分 + overlap_chars = current_chunk[-self.chunk_overlap:] + current_chunk = overlap_chars + " " + sentence + current_size = len(overlap_chars) + sentence_len + 1 + else: + current_chunk = sentence + current_size = sentence_len + else: + if current_chunk: + current_chunk += " " + current_chunk += sentence + current_size += sentence_len + 1 + + # 保存最后一个块 + if current_chunk.strip(): + chunks.append({ + "content": current_chunk.strip(), + "document_id": document_id, + "chunk_index": len(chunks), + "metadata": { + "source": "text" + } + }) + + return chunks + + def process_document(self, text: str, document_id: str = "") -> List[Dict]: + """完整处理文档""" + return self.chunk_text(text, document_id) + + +# 全局实例 +vector_store = VectorStore() +embedding_service = EmbeddingService() + + +def search_knowledge(kb_id: str, query: str, n_results: int = 5) -> Dict: + """知识库检索""" + # 生成查询向量 + query_vector = embedding_service.embed_query(query) + + # 检索 + results = vector_store.search( + kb_id=kb_id, + query=query, + n_results=n_results + ) + + return { + "query": query, + "results": [ + { + "content": doc, + "metadata": meta, + "distance": dist + } + for doc, meta, dist in zip( + results.get("documents", [[]])[0] if results.get("documents") else [], + results.get("metadatas", [[]])[0] if results.get("metadatas") else [], + results.get("distances", [[]])[0] if results.get("distances") else [] + ) + ] + } + + +def index_document(kb_id: str, document_id: str, text: str) -> int: + """索引文档到向量库""" + # 分块 + processor = DocumentProcessor() + chunks = processor.process_document(text, document_id) + + if not chunks: + return 0 + + # 生成向量 + contents = [c["content"] for c in chunks] + embeddings = embedding_service.embed(contents) + + # 添加到向量库 + ids = [f"{document_id}-{c['chunk_index']}" for c in chunks] + metadatas = [ + { + "document_id": 
c["document_id"], + "chunk_index": c["chunk_index"], + "kb_id": kb_id + } + for c in chunks + ] + + vector_store.add_documents( + kb_id=kb_id, + documents=contents, + embeddings=embeddings, + ids=ids, + metadatas=metadatas + ) + + return len(chunks) + + +def delete_document_from_vector(kb_id: str, document_id: str): + """从向量库删除文档""" + vector_store.delete_by_metadata(kb_id, document_id) diff --git a/api/init_db.py b/api/init_db.py new file mode 100644 index 0000000..13f40cd --- /dev/null +++ b/api/init_db.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +"""初始化数据库""" +import os +import sys + +# 添加路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from app.db import Base, engine +from app.models import Voice + + +def init_db(): + """创建所有表""" + print("📦 创建数据库表...") + Base.metadata.drop_all(bind=engine) # 删除旧表 + Base.metadata.create_all(bind=engine) + print("✅ 数据库表创建完成") + + +def init_default_voices(): + """初始化默认声音""" + from app.db import SessionLocal + + db = SessionLocal() + try: + if db.query(Voice).count() == 0: + voices = [ + Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."), + Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."), + Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."), + Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."), + Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."), + ] + for v in voices: + db.add(v) + db.commit() + print("✅ 默认声音数据已初始化") + else: + print("ℹ️ 声音数据已存在,跳过初始化") + finally: + db.close() + + +if __name__ == "__main__": + # 确保 data 目录存在 + data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") + os.makedirs(data_dir, exist_ok=True) + + init_db() + init_default_voices() + print("🎉 数据库初始化完成!") diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..d74dc3e --- /dev/null +++ b/api/main.py @@ -0,0 +1,73 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +import os + +from app.db import Base, engine +from app.routers import assistants, history, knowledge + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动时创建表 + Base.metadata.create_all(bind=engine) + yield + + +app = FastAPI( + title="AI VideoAssistant API", + description="Backend API for AI VideoAssistant", + version="1.0.0", + lifespan=lifespan +) + +# CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 路由 +app.include_router(assistants.router, prefix="/api") +app.include_router(history.router, prefix="/api") +app.include_router(knowledge.router, prefix="/api") + + +@app.get("/") +def root(): + return {"message": "AI VideoAssistant API", "version": "1.0.0"} + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +# 初始化默认数据 +@app.on_event("startup") +def init_default_data(): + from sqlalchemy.orm import Session + from app.db import SessionLocal + from app.models import Voice + + db = SessionLocal() + try: + # 检查是否已有数据 + if db.query(Voice).count() == 0: + # 插入默认声音 + voices = [ + Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."), + Voice(id="v2", name="Kevin", 
vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."), + Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."), + Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."), + Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."), + ] + for v in voices: + db.add(v) + db.commit() + print("✅ 默认声音数据已初始化") + finally: + db.close() diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000..95288af --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,11 @@ +aiosqlite==0.19.0 +fastapi==0.109.0 +uvicorn==0.27.0 +python-multipart==0.0.6 +python-dotenv==1.0.0 +pydantic==2.5.3 +sqlalchemy==2.0.25 +minio==7.2.0 +httpx==0.26.0 +chromadb==0.4.22 +openai==1.12.0 diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000..f38b875 --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,49 @@ +version: '3.8' + +services: + # 后端 API + backend: + build: + context: ./backend + dockerfile: Dockerfile + ports: + - "8000:8000" + environment: + - DATABASE_URL=sqlite:///./data/app.db + - MINIO_ENDPOINT=minio:9000 + - MINIO_ACCESS_KEY=admin + - MINIO_SECRET_KEY=password123 + - MINIO_BUCKET=ai-audio + volumes: + - ./backend:/app + - ./backend/data:/app/data + depends_on: + - minio + + # 对话引擎 (py-active-call) + engine: + build: + context: ../py-active-call + dockerfile: Dockerfile + ports: + - "8001:8001" + environment: + - BACKEND_URL=http://backend:8000 + depends_on: + - backend + + # MinIO (S3 兼容存储) + minio: + image: minio/minio + ports: + - "9000:9000" + - "9001:9001" + volumes: + - ./storage/minio/data:/data + environment: + MINIO_ROOT_USER: admin + MINIO_ROOT_PASSWORD: password123 + command: server /data --console-address ":9001" + +volumes: + minio-data: diff --git a/engine/.gitignore b/engine/.gitignore new file mode 100644 index 0000000..5cd10e8 --- /dev/null +++ b/engine/.gitignore @@ -0,0 +1,148 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# poetry +poetry.lock + +# pdm +.pdm.toml + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype 
static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Project specific +recordings/ +logs/ +running/ diff --git a/engine/README.md b/engine/README.md new file mode 100644 index 0000000..6e7da04 --- /dev/null +++ b/engine/README.md @@ -0,0 +1,25 @@ +# py-active-call-cc + +Python Active-Call: real-time audio streaming with WebSocket and WebRTC. + +This repo contains a Python 3.11+ codebase for building low-latency voice +pipelines (capture, stream, and process audio) using WebRTC and WebSockets. +It is currently in an early, experimental stage. + +# Usage + +启动 + +``` +uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 +``` + +测试 + +``` +python examples/test_websocket.py +``` + +``` +python mic_client.py +``` \ No newline at end of file diff --git a/engine/app/__init__.py b/engine/app/__init__.py new file mode 100644 index 0000000..c136b14 --- /dev/null +++ b/engine/app/__init__.py @@ -0,0 +1 @@ +"""Active-Call Application Package""" diff --git a/engine/app/config.py b/engine/app/config.py new file mode 100644 index 0000000..689eee5 --- /dev/null +++ b/engine/app/config.py @@ -0,0 +1,120 @@ +"""Configuration management using Pydantic settings.""" + +from typing import List, Optional +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict +import json + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore" + ) + + # Server Configuration + host: str = Field(default="0.0.0.0", description="Server host address") + port: int = Field(default=8000, description="Server port") + external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal") + + # Audio Configuration + sample_rate: int = Field(default=16000, description="Audio sample rate in Hz") + chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds") + default_codec: str = Field(default="pcm", description="Default audio codec") + + # VAD Configuration + vad_type: str = Field(default="silero", description="VAD algorithm type") + vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model") + vad_threshold: float = Field(default=0.5, description="VAD detection threshold") + vad_min_speech_duration_ms: int = Field(default=250, description="Minimum speech duration in milliseconds") + vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds") + + # OpenAI / LLM Configuration + openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key") + openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)") + llm_model: str = Field(default="gpt-4o-mini", description="LLM model name") + llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation") + + # TTS Configuration + tts_provider: str = Field(default="siliconflow", description="TTS provider (edge, siliconflow)") + tts_voice: str = Field(default="anna", description="TTS voice name") + tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier") + + # SiliconFlow Configuration + siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key") + siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", 
description="SiliconFlow TTS model") + + # ASR Configuration + asr_provider: str = Field(default="siliconflow", description="ASR provider (siliconflow, buffered)") + siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model") + asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") + asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") + + # Duplex Pipeline Configuration + duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline") + duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message") + duplex_system_prompt: Optional[str] = Field( + default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.", + description="System prompt for LLM" + ) + + # Barge-in (interruption) Configuration + barge_in_min_duration_ms: int = Field( + default=200, + description="Minimum speech duration (ms) required to trigger barge-in. Lower=more sensitive." + ) + + # Logging + log_level: str = Field(default="INFO", description="Logging level") + log_format: str = Field(default="json", description="Log format (json or text)") + + # CORS + cors_origins: str = Field( + default='["http://localhost:3000", "http://localhost:8080"]', + description="CORS allowed origins" + ) + + # ICE Servers (WebRTC) + ice_servers: str = Field( + default='[{"urls": "stun:stun.l.google.com:19302"}]', + description="ICE servers configuration" + ) + + # WebSocket heartbeat and inactivity + inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)") + heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds") + + @property + def chunk_size_bytes(self) -> int: + """Calculate chunk size in bytes based on sample rate and duration.""" + # 16-bit (2 bytes) per sample, mono channel + return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0)) + + @property + def cors_origins_list(self) -> List[str]: + """Parse CORS origins from JSON string.""" + try: + return json.loads(self.cors_origins) + except json.JSONDecodeError: + return ["http://localhost:3000", "http://localhost:8080"] + + @property + def ice_servers_list(self) -> List[dict]: + """Parse ICE servers from JSON string.""" + try: + return json.loads(self.ice_servers) + except json.JSONDecodeError: + return [{"urls": "stun:stun.l.google.com:19302"}] + + +# Global settings instance +settings = Settings() + + +def get_settings() -> Settings: + """Get application settings instance.""" + return settings diff --git a/engine/app/main.py b/engine/app/main.py new file mode 100644 index 0000000..fa77621 --- /dev/null +++ b/engine/app/main.py @@ -0,0 +1,390 @@ +"""FastAPI application with WebSocket and WebRTC endpoints.""" + +import asyncio +import json +import time +import uuid +from pathlib import Path +from typing import Dict, Any, Optional, List +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, FileResponse +from loguru import logger + +# Try to import aiortc (optional for WebRTC functionality) +try: + from aiortc import RTCPeerConnection, RTCSessionDescription + AIORTC_AVAILABLE = True +except ImportError: + AIORTC_AVAILABLE = False + logger.warning("aiortc not available - WebRTC endpoint will be disabled") + 
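`Settings.chunk_size_bytes` above derives the per-frame byte count from `sample_rate` and `chunk_size_ms` (16-bit mono). A quick sanity-check sketch under the defaults (16 kHz, 20 ms chunks), assuming only the `Settings` class defined in `app/config.py`:

```python
from app.config import Settings

# 16 000 samples/s * 2 bytes/sample * 0.020 s = 640 bytes per 20 ms chunk
s = Settings(sample_rate=16000, chunk_size_ms=20)
assert s.chunk_size_bytes == 640
```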
+from app.config import settings +from core.transports import SocketTransport, WebRtcTransport, BaseTransport +from core.session import Session +from processors.tracks import Resampled16kTrack +from core.events import get_event_bus, reset_event_bus + +# Check interval for heartbeat/timeout (seconds) +_HEARTBEAT_CHECK_INTERVAL_SEC = 5 + + +async def heartbeat_and_timeout_task( + transport: BaseTransport, + session: Session, + session_id: str, + last_received_at: List[float], + last_heartbeat_at: List[float], + inactivity_timeout_sec: int, + heartbeat_interval_sec: int, +) -> None: + """ + Background task: send heartBeat every ~heartbeat_interval_sec and close + connection if no message from client for inactivity_timeout_sec. + """ + while True: + await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC) + if transport.is_closed: + break + now = time.monotonic() + if now - last_received_at[0] > inactivity_timeout_sec: + logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing") + await session.cleanup() + break + if now - last_heartbeat_at[0] >= heartbeat_interval_sec: + try: + await transport.send_event({ + "event": "heartBeat", + "timestamp": int(time.time() * 1000), + }) + last_heartbeat_at[0] = now + except Exception as e: + logger.debug(f"Session {session_id}: heartbeat send failed: {e}") + break + + +# Initialize FastAPI +app = FastAPI(title="Python Active-Call", version="0.1.0") +_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent / "examples" / "web_client.html" + +# Configure CORS +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins_list, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Active sessions storage +active_sessions: Dict[str, Session] = {} + +# Configure logging +logger.remove() +logger.add( + "./logs/active_call_{time}.log", + rotation="1 day", + retention="7 days", + level=settings.log_level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}" +) +logger.add( + lambda msg: print(msg, end=""), + level=settings.log_level, + format="{time:HH:mm:ss} | {level: <8} | {message}" +) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "sessions": len(active_sessions)} + + +@app.get("/") +async def web_client_root(): + """Serve the web client.""" + if not _WEB_CLIENT_PATH.exists(): + raise HTTPException(status_code=404, detail="Web client not found") + return FileResponse(_WEB_CLIENT_PATH) + + +@app.get("/client") +async def web_client_alias(): + """Alias for the web client.""" + if not _WEB_CLIENT_PATH.exists(): + raise HTTPException(status_code=404, detail="Web client not found") + return FileResponse(_WEB_CLIENT_PATH) + + + + +@app.get("/iceservers") +async def get_ice_servers(): + """Get ICE servers configuration for WebRTC.""" + return settings.ice_servers_list + + +@app.get("/call/lists") +async def list_calls(): + """List all active calls.""" + return { + "calls": [ + { + "id": session_id, + "state": session.state, + "created_at": session.created_at + } + for session_id, session in active_sessions.items() + ] + } + + +@app.post("/call/kill/{session_id}") +async def kill_call(session_id: str): + """Kill a specific active call.""" + if session_id not in active_sessions: + raise HTTPException(status_code=404, detail="Session not found") + + session = active_sessions[session_id] + await session.cleanup() + del active_sessions[session_id] + + return True + + +@app.websocket("/ws") +async def 
websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for raw audio streaming. + + Accepts mixed text/binary frames: + - Text frames: JSON commands + - Binary frames: PCM audio data (16kHz, 16-bit, mono) + """ + await websocket.accept() + session_id = str(uuid.uuid4()) + + # Create transport and session + transport = SocketTransport(websocket) + session = Session(session_id, transport) + active_sessions[session_id] = session + + logger.info(f"WebSocket connection established: {session_id}") + + last_received_at: List[float] = [time.monotonic()] + last_heartbeat_at: List[float] = [0.0] + hb_task = asyncio.create_task( + heartbeat_and_timeout_task( + transport, + session, + session_id, + last_received_at, + last_heartbeat_at, + settings.inactivity_timeout_sec, + settings.heartbeat_interval_sec, + ) + ) + + try: + # Receive loop + while True: + message = await websocket.receive() + last_received_at[0] = time.monotonic() + + # Handle binary audio data + if "bytes" in message: + await session.handle_audio(message["bytes"]) + + # Handle text commands + elif "text" in message: + await session.handle_text(message["text"]) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected: {session_id}") + + except Exception as e: + logger.error(f"WebSocket error: {e}", exc_info=True) + + finally: + hb_task.cancel() + try: + await hb_task + except asyncio.CancelledError: + pass + # Cleanup session + if session_id in active_sessions: + await session.cleanup() + del active_sessions[session_id] + + logger.info(f"Session {session_id} removed") + + +@app.websocket("/webrtc") +async def webrtc_endpoint(websocket: WebSocket): + """ + WebRTC endpoint for WebRTC audio streaming. + + Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport. 
+ """ + # Check if aiortc is available + if not AIORTC_AVAILABLE: + await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed") + logger.warning("WebRTC connection attempted but aiortc is not available") + return + await websocket.accept() + session_id = str(uuid.uuid4()) + + # Create WebRTC peer connection + pc = RTCPeerConnection() + + # Create transport and session + transport = WebRtcTransport(websocket, pc) + session = Session(session_id, transport) + active_sessions[session_id] = session + + logger.info(f"WebRTC connection established: {session_id}") + + last_received_at: List[float] = [time.monotonic()] + last_heartbeat_at: List[float] = [0.0] + hb_task = asyncio.create_task( + heartbeat_and_timeout_task( + transport, + session, + session_id, + last_received_at, + last_heartbeat_at, + settings.inactivity_timeout_sec, + settings.heartbeat_interval_sec, + ) + ) + + # Track handler for incoming audio + @pc.on("track") + def on_track(track): + logger.info(f"Track received: {track.kind}") + + if track.kind == "audio": + # Wrap track with resampler + wrapped_track = Resampled16kTrack(track) + + # Create task to pull audio from track + async def pull_audio(): + try: + while True: + frame = await wrapped_track.recv() + # Convert frame to bytes + pcm_bytes = frame.to_ndarray().tobytes() + # Feed to session + await session.handle_audio(pcm_bytes) + except Exception as e: + logger.error(f"Error pulling audio from track: {e}") + + asyncio.create_task(pull_audio()) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + logger.info(f"Connection state: {pc.connectionState}") + if pc.connectionState == "failed" or pc.connectionState == "closed": + await session.cleanup() + + try: + # Signaling loop + while True: + message = await websocket.receive() + + if "text" not in message: + continue + + last_received_at[0] = time.monotonic() + data = json.loads(message["text"]) + + # Handle SDP offer/answer + if "sdp" in data and "type" in data: + logger.info(f"Received SDP {data['type']}") + + # Set remote description + offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"]) + await pc.setRemoteDescription(offer) + + # Create and set local description + if data["type"] == "offer": + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + # Send answer back + await websocket.send_text(json.dumps({ + "event": "answer", + "trackId": session_id, + "timestamp": int(asyncio.get_event_loop().time() * 1000), + "sdp": pc.localDescription.sdp + })) + + logger.info(f"Sent SDP answer") + + else: + # Handle other commands + await session.handle_text(message["text"]) + + except WebSocketDisconnect: + logger.info(f"WebRTC WebSocket disconnected: {session_id}") + + except Exception as e: + logger.error(f"WebRTC error: {e}", exc_info=True) + + finally: + hb_task.cancel() + try: + await hb_task + except asyncio.CancelledError: + pass + # Cleanup + await pc.close() + if session_id in active_sessions: + await session.cleanup() + del active_sessions[session_id] + + logger.info(f"WebRTC session {session_id} removed") + + +@app.on_event("startup") +async def startup_event(): + """Run on application startup.""" + logger.info("Starting Python Active-Call server") + logger.info(f"Server: {settings.host}:{settings.port}") + logger.info(f"Sample rate: {settings.sample_rate} Hz") + logger.info(f"VAD model: {settings.vad_model_path}") + + +@app.on_event("shutdown") +async def shutdown_event(): + """Run on application shutdown.""" + logger.info("Shutting 
down Python Active-Call server") + + # Cleanup all sessions + for session_id, session in active_sessions.items(): + await session.cleanup() + + # Close event bus + event_bus = get_event_bus() + await event_bus.close() + reset_event_bus() + + logger.info("Server shutdown complete") + + +if __name__ == "__main__": + import uvicorn + + # Create logs directory + import os + os.makedirs("logs", exist_ok=True) + + # Run server + uvicorn.run( + "app.main:app", + host=settings.host, + port=settings.port, + reload=True, + log_level=settings.log_level.lower() + ) diff --git a/engine/core/__init__.py b/engine/core/__init__.py new file mode 100644 index 0000000..0110686 --- /dev/null +++ b/engine/core/__init__.py @@ -0,0 +1,20 @@ +"""Core Components Package""" + +from core.events import EventBus, get_event_bus +from core.transports import BaseTransport, SocketTransport, WebRtcTransport +from core.session import Session +from core.conversation import ConversationManager, ConversationState, ConversationTurn +from core.duplex_pipeline import DuplexPipeline + +__all__ = [ + "EventBus", + "get_event_bus", + "BaseTransport", + "SocketTransport", + "WebRtcTransport", + "Session", + "ConversationManager", + "ConversationState", + "ConversationTurn", + "DuplexPipeline", +] diff --git a/engine/core/conversation.py b/engine/core/conversation.py new file mode 100644 index 0000000..f3cb63a --- /dev/null +++ b/engine/core/conversation.py @@ -0,0 +1,255 @@ +"""Conversation management for voice AI. + +Handles conversation context, turn-taking, and message history +for multi-turn voice conversations. +""" + +import asyncio +from typing import List, Optional, Dict, Any, Callable, Awaitable +from dataclasses import dataclass, field +from enum import Enum +from loguru import logger + +from services.base import LLMMessage + + +class ConversationState(Enum): + """State of the conversation.""" + IDLE = "idle" # Waiting for user input + LISTENING = "listening" # User is speaking + PROCESSING = "processing" # Processing user input (LLM) + SPEAKING = "speaking" # Bot is speaking + INTERRUPTED = "interrupted" # Bot was interrupted + + +@dataclass +class ConversationTurn: + """A single turn in the conversation.""" + role: str # "user" or "assistant" + text: str + audio_duration_ms: Optional[int] = None + timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time()) + was_interrupted: bool = False + + +class ConversationManager: + """ + Manages conversation state and history. + + Provides: + - Message history for LLM context + - Turn management + - State tracking + - Event callbacks for state changes + """ + + def __init__( + self, + system_prompt: Optional[str] = None, + max_history: int = 20, + greeting: Optional[str] = None + ): + """ + Initialize conversation manager. + + Args: + system_prompt: System prompt for LLM + max_history: Maximum number of turns to keep + greeting: Optional greeting message when conversation starts + """ + self.system_prompt = system_prompt or ( + "You are a helpful, friendly voice assistant. " + "Keep your responses concise and conversational. " + "Respond naturally as if having a phone conversation. " + "If you don't understand something, ask for clarification." 
+ ) + self.max_history = max_history + self.greeting = greeting + + # State + self.state = ConversationState.IDLE + self.turns: List[ConversationTurn] = [] + + # Callbacks + self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = [] + self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = [] + + # Current turn tracking + self._current_user_text: str = "" + self._current_assistant_text: str = "" + + logger.info("ConversationManager initialized") + + def on_state_change( + self, + callback: Callable[[ConversationState, ConversationState], Awaitable[None]] + ) -> None: + """Register callback for state changes.""" + self._state_callbacks.append(callback) + + def on_turn_complete( + self, + callback: Callable[[ConversationTurn], Awaitable[None]] + ) -> None: + """Register callback for turn completion.""" + self._turn_callbacks.append(callback) + + async def set_state(self, new_state: ConversationState) -> None: + """Set conversation state and notify listeners.""" + if new_state != self.state: + old_state = self.state + self.state = new_state + logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}") + + for callback in self._state_callbacks: + try: + await callback(old_state, new_state) + except Exception as e: + logger.error(f"State callback error: {e}") + + def get_messages(self) -> List[LLMMessage]: + """ + Get conversation history as LLM messages. + + Returns: + List of LLMMessage objects including system prompt + """ + messages = [LLMMessage(role="system", content=self.system_prompt)] + + # Add conversation history + for turn in self.turns[-self.max_history:]: + messages.append(LLMMessage(role=turn.role, content=turn.text)) + + # Add current user text if any + if self._current_user_text: + messages.append(LLMMessage(role="user", content=self._current_user_text)) + + return messages + + async def start_user_turn(self) -> None: + """Signal that user has started speaking.""" + await self.set_state(ConversationState.LISTENING) + self._current_user_text = "" + + async def update_user_text(self, text: str, is_final: bool = False) -> None: + """ + Update current user text (from ASR). + + Args: + text: Transcribed text + is_final: Whether this is the final transcript + """ + self._current_user_text = text + + async def end_user_turn(self, text: str) -> None: + """ + End user turn and add to history. + + Args: + text: Final user text + """ + if text.strip(): + turn = ConversationTurn(role="user", text=text.strip()) + self.turns.append(turn) + + for callback in self._turn_callbacks: + try: + await callback(turn) + except Exception as e: + logger.error(f"Turn callback error: {e}") + + logger.info(f"User: {text[:50]}...") + + self._current_user_text = "" + await self.set_state(ConversationState.PROCESSING) + + async def start_assistant_turn(self) -> None: + """Signal that assistant has started speaking.""" + await self.set_state(ConversationState.SPEAKING) + self._current_assistant_text = "" + + async def update_assistant_text(self, text: str) -> None: + """ + Update current assistant text (streaming). + + Args: + text: Text chunk from LLM + """ + self._current_assistant_text += text + + async def end_assistant_turn(self, was_interrupted: bool = False) -> None: + """ + End assistant turn and add to history. 
+ + Args: + was_interrupted: Whether the turn was interrupted by user + """ + text = self._current_assistant_text.strip() + if text: + turn = ConversationTurn( + role="assistant", + text=text, + was_interrupted=was_interrupted + ) + self.turns.append(turn) + + for callback in self._turn_callbacks: + try: + await callback(turn) + except Exception as e: + logger.error(f"Turn callback error: {e}") + + status = " (interrupted)" if was_interrupted else "" + logger.info(f"Assistant{status}: {text[:50]}...") + + self._current_assistant_text = "" + + if was_interrupted: + await self.set_state(ConversationState.INTERRUPTED) + else: + await self.set_state(ConversationState.IDLE) + + async def interrupt(self) -> None: + """Handle interruption (barge-in).""" + if self.state == ConversationState.SPEAKING: + await self.end_assistant_turn(was_interrupted=True) + + def reset(self) -> None: + """Reset conversation history.""" + self.turns = [] + self._current_user_text = "" + self._current_assistant_text = "" + self.state = ConversationState.IDLE + logger.info("Conversation reset") + + @property + def turn_count(self) -> int: + """Get number of turns in conversation.""" + return len(self.turns) + + @property + def last_user_text(self) -> Optional[str]: + """Get last user text.""" + for turn in reversed(self.turns): + if turn.role == "user": + return turn.text + return None + + @property + def last_assistant_text(self) -> Optional[str]: + """Get last assistant text.""" + for turn in reversed(self.turns): + if turn.role == "assistant": + return turn.text + return None + + def get_context_summary(self) -> Dict[str, Any]: + """Get a summary of conversation context.""" + return { + "state": self.state.value, + "turn_count": self.turn_count, + "last_user": self.last_user_text, + "last_assistant": self.last_assistant_text, + "current_user": self._current_user_text or None, + "current_assistant": self._current_assistant_text or None + } diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py new file mode 100644 index 0000000..9d4a938 --- /dev/null +++ b/engine/core/duplex_pipeline.py @@ -0,0 +1,719 @@ +"""Full duplex audio pipeline for AI voice conversation. + +This module implements the core duplex pipeline that orchestrates: +- VAD (Voice Activity Detection) +- EOU (End of Utterance) Detection +- ASR (Automatic Speech Recognition) - optional +- LLM (Language Model) +- TTS (Text-to-Speech) + +Inspired by pipecat's frame-based architecture and active-call's +event-driven design. +""" + +import asyncio +import time +from typing import Optional, Callable, Awaitable +from loguru import logger + +from core.transports import BaseTransport +from core.conversation import ConversationManager, ConversationState +from core.events import get_event_bus +from processors.vad import VADProcessor, SileroVAD +from processors.eou import EouDetector +from services.base import BaseLLMService, BaseTTSService, BaseASRService +from services.llm import OpenAILLMService, MockLLMService +from services.tts import EdgeTTSService, MockTTSService +from services.asr import BufferedASRService +from services.siliconflow_tts import SiliconFlowTTSService +from services.siliconflow_asr import SiliconFlowASRService +from app.config import settings + + +class DuplexPipeline: + """ + Full duplex audio pipeline for AI voice conversation. 
+ + Handles bidirectional audio flow with: + - User speech detection and transcription + - AI response generation + - Text-to-speech synthesis + - Barge-in (interruption) support + + Architecture (inspired by pipecat): + + User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out + ↓ + Barge-in Detection → Interrupt + """ + + def __init__( + self, + transport: BaseTransport, + session_id: str, + llm_service: Optional[BaseLLMService] = None, + tts_service: Optional[BaseTTSService] = None, + asr_service: Optional[BaseASRService] = None, + system_prompt: Optional[str] = None, + greeting: Optional[str] = None + ): + """ + Initialize duplex pipeline. + + Args: + transport: Transport for sending audio/events + session_id: Session identifier + llm_service: LLM service (defaults to OpenAI) + tts_service: TTS service (defaults to EdgeTTS) + asr_service: ASR service (optional) + system_prompt: System prompt for LLM + greeting: Optional greeting to speak on start + """ + self.transport = transport + self.session_id = session_id + self.event_bus = get_event_bus() + + # Initialize VAD + self.vad_model = SileroVAD( + model_path=settings.vad_model_path, + sample_rate=settings.sample_rate + ) + self.vad_processor = VADProcessor( + vad_model=self.vad_model, + threshold=settings.vad_threshold + ) + + # Initialize EOU detector + self.eou_detector = EouDetector( + silence_threshold_ms=settings.vad_eou_threshold_ms, + min_speech_duration_ms=settings.vad_min_speech_duration_ms + ) + + # Initialize services + self.llm_service = llm_service + self.tts_service = tts_service + self.asr_service = asr_service # Will be initialized in start() + + # Track last sent transcript to avoid duplicates + self._last_sent_transcript = "" + + # Conversation manager + self.conversation = ConversationManager( + system_prompt=system_prompt, + greeting=greeting + ) + + # State + self._running = True + self._is_bot_speaking = False + self._current_turn_task: Optional[asyncio.Task] = None + self._audio_buffer: bytes = b"" + max_buffer_seconds = settings.max_audio_buffer_seconds if hasattr(settings, "max_audio_buffer_seconds") else 30 + self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds) + self._last_vad_status: str = "Silence" + self._process_lock = asyncio.Lock() + + # Interruption handling + self._interrupt_event = asyncio.Event() + + # Latency tracking - TTFB (Time to First Byte) + self._turn_start_time: Optional[float] = None + self._first_audio_sent: bool = False + + # Barge-in filtering - require minimum speech duration to interrupt + self._barge_in_speech_start_time: Optional[float] = None + self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms if hasattr(settings, 'barge_in_min_duration_ms') else 50 + self._barge_in_speech_frames: int = 0 # Count speech frames + self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in + self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks) + + logger.info(f"DuplexPipeline initialized for session {session_id}") + + async def start(self) -> None: + """Start the pipeline and connect services.""" + try: + # Connect LLM service + if not self.llm_service: + if settings.openai_api_key: + self.llm_service = OpenAILLMService( + api_key=settings.openai_api_key, + base_url=settings.openai_api_url, + model=settings.llm_model + ) + else: + logger.warning("No OpenAI API key - using mock LLM") + self.llm_service = MockLLMService() + + await self.llm_service.connect() + + # Connect TTS service 
+ if not self.tts_service: + if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key: + self.tts_service = SiliconFlowTTSService( + api_key=settings.siliconflow_api_key, + voice=settings.tts_voice, + model=settings.siliconflow_tts_model, + sample_rate=settings.sample_rate, + speed=settings.tts_speed + ) + logger.info("Using SiliconFlow TTS service") + else: + self.tts_service = EdgeTTSService( + voice=settings.tts_voice, + sample_rate=settings.sample_rate + ) + logger.info("Using Edge TTS service") + + await self.tts_service.connect() + + # Connect ASR service + if not self.asr_service: + if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key: + self.asr_service = SiliconFlowASRService( + api_key=settings.siliconflow_api_key, + model=settings.siliconflow_asr_model, + sample_rate=settings.sample_rate, + interim_interval_ms=settings.asr_interim_interval_ms, + min_audio_for_interim_ms=settings.asr_min_audio_ms, + on_transcript=self._on_transcript_callback + ) + logger.info("Using SiliconFlow ASR service") + else: + self.asr_service = BufferedASRService( + sample_rate=settings.sample_rate + ) + logger.info("Using Buffered ASR service (no real transcription)") + + await self.asr_service.connect() + + logger.info("DuplexPipeline services connected") + + # Speak greeting if configured + if self.conversation.greeting: + await self._speak(self.conversation.greeting) + + except Exception as e: + logger.error(f"Failed to start pipeline: {e}") + raise + + async def process_audio(self, pcm_bytes: bytes) -> None: + """ + Process incoming audio chunk. + + This is the main entry point for audio from the user. + + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + """ + if not self._running: + return + + try: + async with self._process_lock: + # 1. Process through VAD + vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms) + + vad_status = "Silence" + if vad_result: + event_type, probability = vad_result + vad_status = "Speech" if event_type == "speaking" else "Silence" + + # Emit VAD event + await self.event_bus.publish(event_type, { + "trackId": self.session_id, + "probability": probability + }) + else: + # No state change - keep previous status + vad_status = self._last_vad_status + + # Update state based on VAD + if vad_status == "Speech" and self._last_vad_status != "Speech": + await self._on_speech_start() + + self._last_vad_status = vad_status + + # 2. 
Check for barge-in (user speaking while bot speaking) + # Filter false interruptions by requiring minimum speech duration + if self._is_bot_speaking: + if vad_status == "Speech": + # User is speaking while bot is speaking + self._barge_in_silence_frames = 0 # Reset silence counter + + if self._barge_in_speech_start_time is None: + # Start tracking speech duration + self._barge_in_speech_start_time = time.time() + self._barge_in_speech_frames = 1 + logger.debug("Potential barge-in detected, tracking duration...") + else: + self._barge_in_speech_frames += 1 + # Check if speech duration exceeds threshold + speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000 + if speech_duration_ms >= self._barge_in_min_duration_ms: + logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)") + await self._handle_barge_in() + else: + # Silence frame during potential barge-in + if self._barge_in_speech_start_time is not None: + self._barge_in_silence_frames += 1 + # Allow brief silence gaps (VAD flickering) + if self._barge_in_silence_frames > self._barge_in_silence_tolerance: + # Too much silence - reset barge-in tracking + logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames") + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + + # 3. Buffer audio for ASR + if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING: + self._audio_buffer += pcm_bytes + if len(self._audio_buffer) > self._max_audio_buffer_bytes: + # Keep only the most recent audio to cap memory usage + self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:] + await self.asr_service.send_audio(pcm_bytes) + + # For SiliconFlow ASR, trigger interim transcription periodically + # The service handles timing internally via start_interim_transcription() + + # 4. Check for End of Utterance - this triggers LLM response + if self.eou_detector.process(vad_status): + await self._on_end_of_utterance() + + except Exception as e: + logger.error(f"Pipeline audio processing error: {e}", exc_info=True) + + async def process_text(self, text: str) -> None: + """ + Process text input (chat command). + + Allows direct text input to bypass ASR. + + Args: + text: User text input + """ + if not self._running: + return + + logger.info(f"Processing text input: {text[:50]}...") + + # Cancel any current speaking + await self._stop_current_speech() + + # Start new turn + await self.conversation.end_user_turn(text) + self._current_turn_task = asyncio.create_task(self._handle_turn(text)) + + async def interrupt(self) -> None: + """Interrupt current bot speech (manual interrupt command).""" + await self._handle_barge_in() + + async def _on_transcript_callback(self, text: str, is_final: bool) -> None: + """ + Callback for ASR transcription results. + + Streams transcription to client for display. 
+ + Args: + text: Transcribed text + is_final: Whether this is the final transcription + """ + # Avoid sending duplicate transcripts + if text == self._last_sent_transcript and not is_final: + return + + self._last_sent_transcript = text + + # Send transcript event to client + await self.transport.send_event({ + "event": "transcript", + "trackId": self.session_id, + "text": text, + "isFinal": is_final, + "timestamp": self._get_timestamp_ms() + }) + + logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") + + async def _on_speech_start(self) -> None: + """Handle user starting to speak.""" + if self.conversation.state == ConversationState.IDLE: + await self.conversation.start_user_turn() + self._audio_buffer = b"" + self._last_sent_transcript = "" + self.eou_detector.reset() + + # Clear ASR buffer and start interim transcriptions + if hasattr(self.asr_service, 'clear_buffer'): + self.asr_service.clear_buffer() + if hasattr(self.asr_service, 'start_interim_transcription'): + await self.asr_service.start_interim_transcription() + + logger.debug("User speech started") + + async def _on_end_of_utterance(self) -> None: + """Handle end of user utterance.""" + if self.conversation.state != ConversationState.LISTENING: + return + + # Stop interim transcriptions + if hasattr(self.asr_service, 'stop_interim_transcription'): + await self.asr_service.stop_interim_transcription() + + # Get final transcription from ASR service + user_text = "" + + if hasattr(self.asr_service, 'get_final_transcription'): + # SiliconFlow ASR - get final transcription + user_text = await self.asr_service.get_final_transcription() + elif hasattr(self.asr_service, 'get_and_clear_text'): + # Buffered ASR - get accumulated text + user_text = self.asr_service.get_and_clear_text() + + # Skip if no meaningful text + if not user_text or not user_text.strip(): + logger.debug("EOU detected but no transcription - skipping") + # Reset for next utterance + self._audio_buffer = b"" + self._last_sent_transcript = "" + # Return to idle; don't force LISTENING which causes buffering on silence + await self.conversation.set_state(ConversationState.IDLE) + return + + logger.info(f"EOU detected - user said: {user_text[:100]}...") + + # Send final transcription to client + await self.transport.send_event({ + "event": "transcript", + "trackId": self.session_id, + "text": user_text, + "isFinal": True, + "timestamp": self._get_timestamp_ms() + }) + + # Clear buffers + self._audio_buffer = b"" + self._last_sent_transcript = "" + + # Process the turn - trigger LLM response + # Cancel any existing turn to avoid overlapping assistant responses + await self._stop_current_speech() + await self.conversation.end_user_turn(user_text) + self._current_turn_task = asyncio.create_task(self._handle_turn(user_text)) + + async def _handle_turn(self, user_text: str) -> None: + """ + Handle a complete conversation turn. + + Uses sentence-by-sentence streaming TTS for lower latency. 
+ + Args: + user_text: User's transcribed text + """ + try: + # Start latency tracking + self._turn_start_time = time.time() + self._first_audio_sent = False + + # Get AI response (streaming) + messages = self.conversation.get_messages() + full_response = "" + + await self.conversation.start_assistant_turn() + self._is_bot_speaking = True + self._interrupt_event.clear() + + # Sentence buffer for streaming TTS + sentence_buffer = "" + sentence_ends = {',', '。', '!', '?', '\n'} + first_audio_sent = False + + # Stream LLM response and TTS sentence by sentence + async for text_chunk in self.llm_service.generate_stream(messages): + if self._interrupt_event.is_set(): + break + + full_response += text_chunk + sentence_buffer += text_chunk + await self.conversation.update_assistant_text(text_chunk) + + # Send LLM response streaming event to client + await self.transport.send_event({ + "event": "llmResponse", + "trackId": self.session_id, + "text": text_chunk, + "isFinal": False, + "timestamp": self._get_timestamp_ms() + }) + + # Check for sentence completion - synthesize immediately for low latency + while any(end in sentence_buffer for end in sentence_ends): + # Find first sentence end + min_idx = len(sentence_buffer) + for end in sentence_ends: + idx = sentence_buffer.find(end) + if idx != -1 and idx < min_idx: + min_idx = idx + + if min_idx < len(sentence_buffer): + sentence = sentence_buffer[:min_idx + 1].strip() + sentence_buffer = sentence_buffer[min_idx + 1:] + + if sentence and not self._interrupt_event.is_set(): + # Send track start on first audio + if not first_audio_sent: + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + first_audio_sent = True + + # Synthesize and send this sentence immediately + await self._speak_sentence(sentence) + else: + break + + # Send final LLM response event + if full_response and not self._interrupt_event.is_set(): + await self.transport.send_event({ + "event": "llmResponse", + "trackId": self.session_id, + "text": full_response, + "isFinal": True, + "timestamp": self._get_timestamp_ms() + }) + + # Speak any remaining text + if sentence_buffer.strip() and not self._interrupt_event.is_set(): + if not first_audio_sent: + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + first_audio_sent = True + await self._speak_sentence(sentence_buffer.strip()) + + # Send track end + if first_audio_sent: + await self.transport.send_event({ + "event": "trackEnd", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + # End assistant turn + await self.conversation.end_assistant_turn( + was_interrupted=self._interrupt_event.is_set() + ) + + except asyncio.CancelledError: + logger.info("Turn handling cancelled") + await self.conversation.end_assistant_turn(was_interrupted=True) + except Exception as e: + logger.error(f"Turn handling error: {e}", exc_info=True) + await self.conversation.end_assistant_turn(was_interrupted=True) + finally: + self._is_bot_speaking = False + # Reset barge-in tracking when bot finishes speaking + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + + async def _speak_sentence(self, text: str) -> None: + """ + Synthesize and send a single sentence. 
+ + Args: + text: Sentence to speak + """ + if not text.strip() or self._interrupt_event.is_set(): + return + + try: + async for chunk in self.tts_service.synthesize_stream(text): + # Check interrupt at the start of each iteration + if self._interrupt_event.is_set(): + logger.debug("TTS sentence interrupted") + break + + # Track and log first audio packet latency (TTFB) + if not self._first_audio_sent and self._turn_start_time: + ttfb_ms = (time.time() - self._turn_start_time) * 1000 + self._first_audio_sent = True + logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})") + + # Send TTFB event to client + await self.transport.send_event({ + "event": "ttfb", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "latencyMs": round(ttfb_ms) + }) + + # Double-check interrupt right before sending audio + if self._interrupt_event.is_set(): + break + + await self.transport.send_audio(chunk.audio) + await asyncio.sleep(0.005) # Small delay to prevent flooding + except asyncio.CancelledError: + logger.debug("TTS sentence cancelled") + except Exception as e: + logger.error(f"TTS sentence error: {e}") + + async def _speak(self, text: str) -> None: + """ + Synthesize and send speech. + + Args: + text: Text to speak + """ + if not text.strip(): + return + + try: + # Start latency tracking for greeting + speak_start_time = time.time() + first_audio_sent = False + + # Send track start event + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + self._is_bot_speaking = True + + # Stream TTS audio + async for chunk in self.tts_service.synthesize_stream(text): + if self._interrupt_event.is_set(): + logger.info("TTS interrupted by barge-in") + break + + # Track and log first audio packet latency (TTFB) + if not first_audio_sent: + ttfb_ms = (time.time() - speak_start_time) * 1000 + first_audio_sent = True + logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})") + + # Send TTFB event to client + await self.transport.send_event({ + "event": "ttfb", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "latencyMs": round(ttfb_ms) + }) + + # Send audio to client + await self.transport.send_audio(chunk.audio) + + # Small delay to prevent flooding + await asyncio.sleep(0.01) + + # Send track end event + await self.transport.send_event({ + "event": "trackEnd", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + except asyncio.CancelledError: + logger.info("TTS cancelled") + raise + except Exception as e: + logger.error(f"TTS error: {e}") + finally: + self._is_bot_speaking = False + + async def _handle_barge_in(self) -> None: + """Handle user barge-in (interruption).""" + if not self._is_bot_speaking: + return + + logger.info("Barge-in detected - interrupting bot speech") + + # Reset barge-in tracking + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + + # IMPORTANT: Signal interruption FIRST to stop audio sending + self._interrupt_event.set() + self._is_bot_speaking = False + + # Send interrupt event to client IMMEDIATELY + # This must happen BEFORE canceling services, so client knows to discard in-flight audio + await self.transport.send_event({ + "event": "interrupt", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + # Cancel TTS + if self.tts_service: + await 
self.tts_service.cancel() + + # Cancel LLM + if self.llm_service and hasattr(self.llm_service, 'cancel'): + self.llm_service.cancel() + + # Interrupt conversation only if there is no active turn task. + # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks. + if not (self._current_turn_task and not self._current_turn_task.done()): + await self.conversation.interrupt() + + # Reset for new user turn + await self.conversation.start_user_turn() + self._audio_buffer = b"" + self.eou_detector.reset() + + async def _stop_current_speech(self) -> None: + """Stop any current speech task.""" + if self._current_turn_task and not self._current_turn_task.done(): + self._interrupt_event.set() + self._current_turn_task.cancel() + try: + await self._current_turn_task + except asyncio.CancelledError: + pass + + # Ensure underlying services are cancelled to avoid leaking work/audio + if self.tts_service: + await self.tts_service.cancel() + if self.llm_service and hasattr(self.llm_service, 'cancel'): + self.llm_service.cancel() + + self._is_bot_speaking = False + self._interrupt_event.clear() + + async def cleanup(self) -> None: + """Cleanup pipeline resources.""" + logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}") + + self._running = False + await self._stop_current_speech() + + # Disconnect services + if self.llm_service: + await self.llm_service.disconnect() + if self.tts_service: + await self.tts_service.disconnect() + if self.asr_service: + await self.asr_service.disconnect() + + def _get_timestamp_ms(self) -> int: + """Get current timestamp in milliseconds.""" + import time + return int(time.time() * 1000) + + @property + def is_speaking(self) -> bool: + """Check if bot is currently speaking.""" + return self._is_bot_speaking + + @property + def state(self) -> ConversationState: + """Get current conversation state.""" + return self.conversation.state diff --git a/engine/core/events.py b/engine/core/events.py new file mode 100644 index 0000000..1762148 --- /dev/null +++ b/engine/core/events.py @@ -0,0 +1,134 @@ +"""Event bus for pub/sub communication between components.""" + +import asyncio +from typing import Callable, Dict, List, Any, Optional +from collections import defaultdict +from loguru import logger + + +class EventBus: + """ + Async event bus for pub/sub communication. + + Similar to the original Rust implementation's broadcast channel. + Components can subscribe to specific event types and receive events asynchronously. + """ + + def __init__(self): + """Initialize the event bus.""" + self._subscribers: Dict[str, List[Callable]] = defaultdict(list) + self._lock = asyncio.Lock() + self._running = True + + def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None: + """ + Subscribe to an event type. + + Args: + event_type: Type of event to subscribe to (e.g., "speaking", "silence") + callback: Async callback function that receives event data + """ + if not self._running: + logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}") + return + + self._subscribers[event_type].append(callback) + logger.debug(f"Subscribed to event type: {event_type}") + + def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None: + """ + Unsubscribe from an event type. 
+ + Args: + event_type: Type of event to unsubscribe from + callback: Callback function to remove + """ + if callback in self._subscribers[event_type]: + self._subscribers[event_type].remove(callback) + logger.debug(f"Unsubscribed from event type: {event_type}") + + async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None: + """ + Publish an event to all subscribers. + + Args: + event_type: Type of event to publish + event_data: Event data to send to subscribers + """ + if not self._running: + logger.warning(f"Event bus is shut down, ignoring event: {event_type}") + return + + # Get subscribers for this event type + subscribers = self._subscribers.get(event_type, []) + + if not subscribers: + logger.debug(f"No subscribers for event type: {event_type}") + return + + # Notify all subscribers concurrently + tasks = [] + for callback in subscribers: + try: + # Create task for each subscriber + task = asyncio.create_task(self._call_subscriber(callback, event_data)) + tasks.append(task) + except Exception as e: + logger.error(f"Error creating task for subscriber: {e}") + + # Wait for all subscribers to complete + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers") + + async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None: + """ + Call a subscriber callback with error handling. + + Args: + callback: Subscriber callback function + event_data: Event data to pass to callback + """ + try: + # Check if callback is a coroutine function + if asyncio.iscoroutinefunction(callback): + await callback(event_data) + else: + callback(event_data) + except Exception as e: + logger.error(f"Error in subscriber callback: {e}", exc_info=True) + + async def close(self) -> None: + """Close the event bus and stop processing events.""" + self._running = False + self._subscribers.clear() + logger.info("Event bus closed") + + @property + def is_running(self) -> bool: + """Check if the event bus is running.""" + return self._running + + +# Global event bus instance +_event_bus: Optional[EventBus] = None + + +def get_event_bus() -> EventBus: + """ + Get the global event bus instance. + + Returns: + EventBus instance + """ + global _event_bus + if _event_bus is None: + _event_bus = EventBus() + return _event_bus + + +def reset_event_bus() -> None: + """Reset the global event bus (mainly for testing).""" + global _event_bus + _event_bus = None diff --git a/engine/core/session.py b/engine/core/session.py new file mode 100644 index 0000000..54bf0d4 --- /dev/null +++ b/engine/core/session.py @@ -0,0 +1,285 @@ +"""Session management for active calls.""" + +import uuid +import json +from typing import Optional, Dict, Any +from loguru import logger + +from core.transports import BaseTransport +from core.duplex_pipeline import DuplexPipeline +from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand +from app.config import settings + + +class Session: + """ + Manages a single call session. + + Handles command routing, audio processing, and session lifecycle. + Uses full duplex voice conversation pipeline. + """ + + def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): + """ + Initialize session. 
+ + Args: + session_id: Unique session identifier + transport: Transport instance for communication + use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled) + """ + self.id = session_id + self.transport = transport + self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled + + self.pipeline = DuplexPipeline( + transport=transport, + session_id=session_id, + system_prompt=settings.duplex_system_prompt, + greeting=settings.duplex_greeting + ) + + # Session state + self.created_at = None + self.state = "created" # created, invited, accepted, ringing, hungup + self._pipeline_started = False + + # Track IDs + self.current_track_id: Optional[str] = str(uuid.uuid4()) + + logger.info(f"Session {self.id} created (duplex={self.use_duplex})") + + async def handle_text(self, text_data: str) -> None: + """ + Handle incoming text data (JSON commands). + + Args: + text_data: JSON text data + """ + try: + data = json.loads(text_data) + command = parse_command(data) + command_type = command.command + + logger.info(f"Session {self.id} received command: {command_type}") + + # Route command to appropriate handler + if command_type == "invite": + await self._handle_invite(data) + + elif command_type == "accept": + await self._handle_accept(data) + + elif command_type == "reject": + await self._handle_reject(data) + + elif command_type == "ringing": + await self._handle_ringing(data) + + elif command_type == "tts": + await self._handle_tts(command) + + elif command_type == "play": + await self._handle_play(data) + + elif command_type == "interrupt": + await self._handle_interrupt(command) + + elif command_type == "pause": + await self._handle_pause() + + elif command_type == "resume": + await self._handle_resume() + + elif command_type == "hangup": + await self._handle_hangup(command) + + elif command_type == "history": + await self._handle_history(data) + + elif command_type == "chat": + await self._handle_chat(command) + + else: + logger.warning(f"Session {self.id} unknown command: {command_type}") + + except json.JSONDecodeError as e: + logger.error(f"Session {self.id} JSON decode error: {e}") + await self._send_error("client", f"Invalid JSON: {e}") + + except ValueError as e: + logger.error(f"Session {self.id} command parse error: {e}") + await self._send_error("client", f"Invalid command: {e}") + + except Exception as e: + logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True) + await self._send_error("server", f"Internal error: {e}") + + async def handle_audio(self, audio_bytes: bytes) -> None: + """ + Handle incoming audio data. 
+ + Args: + audio_bytes: PCM audio data + """ + try: + await self.pipeline.process_audio(audio_bytes) + except Exception as e: + logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) + + async def _handle_invite(self, data: Dict[str, Any]) -> None: + """Handle invite command.""" + self.state = "invited" + option = data.get("option", {}) + + # Send answer event + await self.transport.send_event({ + "event": "answer", + "trackId": self.current_track_id, + "timestamp": self._get_timestamp_ms() + }) + + # Start duplex pipeline + if not self._pipeline_started: + try: + await self.pipeline.start() + self._pipeline_started = True + logger.info(f"Session {self.id} duplex pipeline started") + except Exception as e: + logger.error(f"Failed to start duplex pipeline: {e}") + + logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}") + + async def _handle_accept(self, data: Dict[str, Any]) -> None: + """Handle accept command.""" + self.state = "accepted" + logger.info(f"Session {self.id} accepted") + + async def _handle_reject(self, data: Dict[str, Any]) -> None: + """Handle reject command.""" + self.state = "rejected" + reason = data.get("reason", "Rejected") + logger.info(f"Session {self.id} rejected: {reason}") + + async def _handle_ringing(self, data: Dict[str, Any]) -> None: + """Handle ringing command.""" + self.state = "ringing" + logger.info(f"Session {self.id} ringing") + + async def _handle_tts(self, command: TTSCommand) -> None: + """Handle TTS command.""" + logger.info(f"Session {self.id} TTS: {command.text[:50]}...") + + # Send track start event + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.current_track_id, + "timestamp": self._get_timestamp_ms(), + "playId": command.play_id + }) + + # TODO: Implement actual TTS synthesis + # For now, just send track end event + await self.transport.send_event({ + "event": "trackEnd", + "trackId": self.current_track_id, + "timestamp": self._get_timestamp_ms(), + "duration": 1000, + "ssrc": 0, + "playId": command.play_id + }) + + async def _handle_play(self, data: Dict[str, Any]) -> None: + """Handle play command.""" + url = data.get("url", "") + logger.info(f"Session {self.id} play: {url}") + + # Send track start event + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.current_track_id, + "timestamp": self._get_timestamp_ms(), + "playId": url + }) + + # TODO: Implement actual audio playback + # For now, just send track end event + await self.transport.send_event({ + "event": "trackEnd", + "trackId": self.current_track_id, + "timestamp": self._get_timestamp_ms(), + "duration": 1000, + "ssrc": 0, + "playId": url + }) + + async def _handle_interrupt(self, command: InterruptCommand) -> None: + """Handle interrupt command.""" + if command.graceful: + logger.info(f"Session {self.id} graceful interrupt") + else: + logger.info(f"Session {self.id} immediate interrupt") + await self.pipeline.interrupt() + + async def _handle_pause(self) -> None: + """Handle pause command.""" + logger.info(f"Session {self.id} paused") + + async def _handle_resume(self) -> None: + """Handle resume command.""" + logger.info(f"Session {self.id} resumed") + + async def _handle_hangup(self, command: HangupCommand) -> None: + """Handle hangup command.""" + self.state = "hungup" + reason = command.reason or "User requested" + logger.info(f"Session {self.id} hung up: {reason}") + + # Send hangup event + await self.transport.send_event({ + "event": "hangup", + "timestamp": 
self._get_timestamp_ms(), + "reason": reason, + "initiator": command.initiator or "user" + }) + + # Close transport + await self.transport.close() + + async def _handle_history(self, data: Dict[str, Any]) -> None: + """Handle history command.""" + speaker = data.get("speaker", "unknown") + text = data.get("text", "") + logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...") + + async def _handle_chat(self, command: ChatCommand) -> None: + """Handle chat command.""" + logger.info(f"Session {self.id} chat: {command.text[:50]}...") + await self.pipeline.process_text(command.text) + + async def _send_error(self, sender: str, error_message: str) -> None: + """ + Send error event to client. + + Args: + sender: Component that generated the error + error_message: Error message + """ + await self.transport.send_event({ + "event": "error", + "trackId": self.current_track_id, + "timestamp": self._get_timestamp_ms(), + "sender": sender, + "error": error_message + }) + + def _get_timestamp_ms(self) -> int: + """Get current timestamp in milliseconds.""" + import time + return int(time.time() * 1000) + + async def cleanup(self) -> None: + """Cleanup session resources.""" + logger.info(f"Session {self.id} cleaning up") + await self.pipeline.cleanup() + await self.transport.close() diff --git a/engine/core/transports.py b/engine/core/transports.py new file mode 100644 index 0000000..6945225 --- /dev/null +++ b/engine/core/transports.py @@ -0,0 +1,207 @@ +"""Transport layer for WebSocket and WebRTC communication.""" + +import asyncio +import json +from abc import ABC, abstractmethod +from typing import Optional +from fastapi import WebSocket +from loguru import logger + +# Try to import aiortc (optional for WebRTC functionality) +try: + from aiortc import RTCPeerConnection + AIORTC_AVAILABLE = True +except ImportError: + AIORTC_AVAILABLE = False + RTCPeerConnection = None # Type hint placeholder + + +class BaseTransport(ABC): + """ + Abstract base class for transports. + + All transports must implement send_event and send_audio methods. + """ + + @abstractmethod + async def send_event(self, event: dict) -> None: + """ + Send a JSON event to the client. + + Args: + event: Event data as dictionary + """ + pass + + @abstractmethod + async def send_audio(self, pcm_bytes: bytes) -> None: + """ + Send audio data to the client. + + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + """ + pass + + @abstractmethod + async def close(self) -> None: + """Close the transport and cleanup resources.""" + pass + + +class SocketTransport(BaseTransport): + """ + WebSocket transport for raw audio streaming. + + Handles mixed text/binary frames over WebSocket connection. + Uses asyncio.Lock to prevent frame interleaving. + """ + + def __init__(self, websocket: WebSocket): + """ + Initialize WebSocket transport. + + Args: + websocket: FastAPI WebSocket connection + """ + self.ws = websocket + self.lock = asyncio.Lock() # Prevent frame interleaving + self._closed = False + + async def send_event(self, event: dict) -> None: + """ + Send a JSON event via WebSocket. 
+ + Args: + event: Event data as dictionary + """ + if self._closed: + logger.warning("Attempted to send event on closed transport") + return + + async with self.lock: + try: + await self.ws.send_text(json.dumps(event)) + logger.debug(f"Sent event: {event.get('event', 'unknown')}") + except Exception as e: + logger.error(f"Error sending event: {e}") + self._closed = True + + async def send_audio(self, pcm_bytes: bytes) -> None: + """ + Send PCM audio data via WebSocket. + + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + """ + if self._closed: + logger.warning("Attempted to send audio on closed transport") + return + + async with self.lock: + try: + await self.ws.send_bytes(pcm_bytes) + except Exception as e: + logger.error(f"Error sending audio: {e}") + self._closed = True + + async def close(self) -> None: + """Close the WebSocket connection.""" + self._closed = True + try: + await self.ws.close() + except Exception as e: + logger.error(f"Error closing WebSocket: {e}") + + @property + def is_closed(self) -> bool: + """Check if the transport is closed.""" + return self._closed + + +class WebRtcTransport(BaseTransport): + """ + WebRTC transport for WebRTC audio streaming. + + Uses WebSocket for signaling and RTCPeerConnection for media. + """ + + def __init__(self, websocket: WebSocket, pc): + """ + Initialize WebRTC transport. + + Args: + websocket: FastAPI WebSocket connection for signaling + pc: RTCPeerConnection for media transport + """ + if not AIORTC_AVAILABLE: + raise RuntimeError("aiortc is not available - WebRTC transport cannot be used") + + self.ws = websocket + self.pc = pc + self.outbound_track = None # MediaStreamTrack for outbound audio + self._closed = False + + async def send_event(self, event: dict) -> None: + """ + Send a JSON event via WebSocket signaling. + + Args: + event: Event data as dictionary + """ + if self._closed: + logger.warning("Attempted to send event on closed transport") + return + + try: + await self.ws.send_text(json.dumps(event)) + logger.debug(f"Sent event: {event.get('event', 'unknown')}") + except Exception as e: + logger.error(f"Error sending event: {e}") + self._closed = True + + async def send_audio(self, pcm_bytes: bytes) -> None: + """ + Send audio data via WebRTC track. + + Note: In WebRTC, you don't send bytes directly. You push frames + to a MediaStreamTrack that the peer connection is reading. + + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + """ + if self._closed: + logger.warning("Attempted to send audio on closed transport") + return + + # This would require a custom MediaStreamTrack implementation + # For now, we'll log this as a placeholder + logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes") + + # TODO: Implement outbound audio track if needed + # if self.outbound_track: + # await self.outbound_track.add_frame(pcm_bytes) + + async def close(self) -> None: + """Close the WebRTC connection.""" + self._closed = True + try: + await self.pc.close() + await self.ws.close() + except Exception as e: + logger.error(f"Error closing WebRTC transport: {e}") + + @property + def is_closed(self) -> bool: + """Check if the transport is closed.""" + return self._closed + + def set_outbound_track(self, track): + """ + Set the outbound audio track for sending audio to client. 
+ + Args: + track: MediaStreamTrack for outbound audio + """ + self.outbound_track = track + logger.debug("Set outbound track for WebRTC transport") diff --git a/engine/data/audio_examples/single_utterance_16k.wav b/engine/data/audio_examples/single_utterance_16k.wav new file mode 100644 index 0000000..8c7bbe5 Binary files /dev/null and b/engine/data/audio_examples/single_utterance_16k.wav differ diff --git a/engine/data/audio_examples/three_utterances.wav b/engine/data/audio_examples/three_utterances.wav new file mode 100644 index 0000000..c2dca2f Binary files /dev/null and b/engine/data/audio_examples/three_utterances.wav differ diff --git a/engine/data/audio_examples/two_utterances.wav b/engine/data/audio_examples/two_utterances.wav new file mode 100644 index 0000000..5c66f70 Binary files /dev/null and b/engine/data/audio_examples/two_utterances.wav differ diff --git a/engine/data/vad/silero_vad.onnx b/engine/data/vad/silero_vad.onnx new file mode 100644 index 0000000..b3e3a90 Binary files /dev/null and b/engine/data/vad/silero_vad.onnx differ diff --git a/engine/docs/duplex_interaction.svg b/engine/docs/duplex_interaction.svg new file mode 100644 index 0000000..9ccd0bb --- /dev/null +++ b/engine/docs/duplex_interaction.svg @@ -0,0 +1,96 @@ + + + + + + + + + + Web Client + WS JSON commands + WS binary PCM audio + + + FastAPI /ws + Session + Transport + + + DuplexPipeline + process_audio / process_text + + + ConversationManager + turns + state + + + VADProcessor + speech/silence + + + EOU Detector + end-of-utterance + + + ASR + transcripts + + + LLM (stream) + llmResponse events + + + TTS (stream) + PCM audio + + + Web Client + audio playback + UI + + + JSON / PCM + + + dispatch + + + turn mgmt + + + audio chunks + + + vad status + + + audio buffer + + + EOU -> LLM + + + text stream + + + PCM audio + + + events: trackStart/End + + + UI updates + + + barge-in detection + + + interrupt event + cancel + diff --git a/engine/docs/proejct_todo.md b/engine/docs/proejct_todo.md new file mode 100644 index 0000000..18a9f17 --- /dev/null +++ b/engine/docs/proejct_todo.md @@ -0,0 +1,187 @@ +# OmniSense: 12-Week Sprint Board + Tech Stack (Python Backend) — TODO + +## Scope +- [ ] Build a realtime AI SaaS (OmniSense) focused on web-first audio + video with WebSocket + WebRTC endpoints +- [ ] Deliver assistant builder, tool execution, observability, evals, optional telephony later +- [ ] Keep scope aligned to 2-person team, self-hosted services + +--- + +## Sprint Board (12 weeks, 2-week sprints) +Team assumption: 2 engineers. Scope prioritized to web-first audio + video, with BYO-SFU adapters. 
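The BYO-SFU direction is pinned down at the end of this file in "RTC Adapter Contract (BYO-SFU First)", which calls for a TypeScript sketch of the adapter interface. A minimal, hypothetical shape for that contract could look like the following; the interface and method names are illustrative, not an existing API:

```typescript
// Hypothetical RTC adapter contract -- a sketch of the interface the
// "RTC Adapter Contract (BYO-SFU First)" section below asks for.
export interface RtcAdapter {
  /** Establish media + signaling (WebRTC offer/answer, or plain WebSocket in dev mode). */
  connect(opts: { url: string; token?: string }): Promise<void>;

  /** Push one outbound audio frame (16-bit PCM) toward the remote peer. */
  sendAudio(pcm: ArrayBuffer): void;

  /** Subscribe to inbound audio frames from the remote peer. */
  onAudio(handler: (pcm: ArrayBuffer) => void): void;

  /** Subscribe to JSON control events (transcript, trackStart, interrupt, ...). */
  onEvent(handler: (event: Record<string, unknown>) => void): void;

  /** Tear down tracks, the peer connection, and signaling. */
  close(): Promise<void>;
}
```

Coding the session core against an interface like this is what keeps LiveKit an optional adapter rather than a core dependency.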
+ +### Sprint 1 (Weeks 1–2) — Realtime Core MVP (WebSocket + WebRTC Audio) +- Deliverables + - [ ] WebSocket transport: audio in/out streaming (1:1) + - [ ] WebRTC transport: audio in/out streaming (1:1) + - [ ] Adapter contract wired into runtime (transport-agnostic session core) + - [ ] ASR → LLM → TTS pipeline, streaming both directions + - [ ] Basic session state (start/stop, silence timeout) + - [ ] Transcript persistence +- Acceptance criteria + - [ ] < 1.5s median round-trip for short responses + - [ ] Stable streaming for 10+ minute session + +### Sprint 2 (Weeks 3–4) — Video + Realtime UX +- Deliverables + - [ ] WebRTC video capture + streaming (assistant can “see” frames) + - [ ] WebSocket video streaming for local/dev mode + - [ ] Low-latency UI: push-to-talk, live captions, speaking indicator + - [ ] Recording + transcript storage (web sessions) +- Acceptance criteria + - [ ] Video < 2.5s end-to-end latency for analysis + - [ ] Audio quality acceptable (no clipping, jitter handling) + +### Sprint 3 (Weeks 5–6) — Assistant Builder v1 +- Deliverables + - [ ] Assistant schema + versioning + - [ ] UI: Model/Voice/Transcriber/Tools/Video/Transport tabs + - [ ] “Test/Chat/Talk to Assistant” (web) +- Acceptance criteria + - [ ] Create/publish assistant and run a live web session + - [ ] All config changes tracked by version + +### Sprint 4 (Weeks 7–8) — Tooling + Structured Outputs +- Deliverables + - [ ] Tool registry + custom HTTP tools + - [ ] Tool auth secrets management + - [ ] Structured outputs (JSON extraction) +- Acceptance criteria + - [ ] Tool calls executed with retries/timeouts + - [ ] Structured JSON stored per call/session + +### Sprint 5 (Weeks 9–10) — Observability + QA + Dev Platform +- Deliverables + - [ ] Session logs + chat logs + media logs + - [ ] Evals engine + test suites + - [ ] Basic analytics dashboard + - [ ] Public WebSocket API spec + message schema + - [ ] JS/TS SDK (connect, send audio/video, receive transcripts) +- Acceptance criteria + - [ ] Reproducible test suite runs + - [ ] Log filters by assistant/time/status + - [ ] SDK demo app runs end-to-end + +### Sprint 6 (Weeks 11–12) — SaaS Hardening +- Deliverables + - [ ] Org/RBAC + API keys + rate limits + - [ ] Usage metering + credits + - [ ] Stripe billing integration + - [ ] Self-hosted DB ops (migrations, backup/restore, monitoring) +- Acceptance criteria + - [ ] Metered usage per org + - [ ] Credits decrement correctly + - [ ] Optional telephony spike documented (defer build) + - [ ] Enterprise adapter guide published (BYO-SFU) + +--- + +## Tech Stack by Service (Self-Hosted, Web-First) + +### 1) Transport Gateway (Realtime) +- [ ] WebRTC (browser) + WebSocket (lightweight/dev) protocols +- [ ] BYO-SFU adapter (enterprise) + LiveKit optional adapter + WS transport server +- [ ] Python core (FastAPI + asyncio) + Node.js mediasoup adapters when needed +- [ ] Media: Opus/VP8, jitter buffer, VAD, echo cancellation +- [ ] Storage: S3-compatible (MinIO) for recordings + +### 2) ASR Service +- [ ] Whisper (self-hosted) baseline +- [ ] gRPC/WebSocket streaming transport +- [ ] Python native service +- [ ] Optional cloud provider fallback (later) + +### 3) TTS Service +- [ ] Piper or Coqui TTS (self-hosted) +- [ ] gRPC/WebSocket streaming transport +- [ ] Python native service +- [ ] Redis cache for common phrases + +### 4) LLM Orchestrator +- [ ] Self-hosted (vLLM + open model) +- [ ] Python (FastAPI + asyncio) +- [ ] Streaming, tool calling, JSON mode +- [ ] Safety filters + prompt templates + +### 5) 
Assistant Config Service +- [ ] PostgreSQL +- [ ] Python (SQLAlchemy or SQLModel) +- [ ] Versioning, publish/rollback + +### 6) Session Service +- [ ] PostgreSQL + Redis +- [ ] Python +- [ ] State machine, timeouts, events + +### 7) Tool Execution Layer +- [ ] PostgreSQL +- [ ] Python +- [ ] Auth secret vault, retry policies, tool schemas + +### 8) Observability + Logs +- [ ] Postgres (metadata), ClickHouse (logs/metrics) +- [ ] OpenSearch for search +- [ ] Prometheus + Grafana metrics +- [ ] OpenTelemetry tracing + +### 9) Billing + Usage Metering +- [ ] Stripe billing +- [ ] PostgreSQL +- [ ] NATS JetStream (events) + Redis counters + +### 10) Web App (Dashboard) +- [ ] React + Next.js +- [ ] Tailwind or Radix UI +- [ ] WebRTC client + WS client; adapter-based RTC integration +- [ ] ECharts/Recharts + +### 11) Auth + RBAC +- [ ] Keycloak (self-hosted) or custom JWT +- [ ] Org/user/role tables in Postgres + +### 12) Public WebSocket API + SDK +- [ ] WS API: versioned schema, binary audio frames + JSON control messages +- [ ] SDKs: JS/TS first, optional Python/Go clients +- [ ] Docs: quickstart, auth flow, session lifecycle, examples + +--- + +## Infrastructure (Self-Hosted) +- [ ] Docker Compose → k3s (later) +- [ ] Redis Streams or NATS +- [ ] MinIO object store +- [ ] GitHub Actions + Helm or kustomize +- [ ] Self-hosted Postgres + pgbackrest backups +- [ ] Vault for secrets + +--- + +## Suggested MVP Sequence +- [ ] WebRTC demo + ASR/LLM/TTS streaming +- [ ] Assistant schema + versioning (web-first) +- [ ] Video capture + multimodal analysis +- [ ] Tool execution + structured outputs +- [ ] Logs + evals + public WS API + SDK +- [ ] Telephony (optional, later) + +--- + +## Public WebSocket API (Minimum Spec) +- [ ] Auth: API key or JWT in initial `hello` message +- [ ] Core messages: `session.start`, `session.stop`, `audio.append`, `audio.commit`, `video.append`, `transcript.delta`, `assistant.response`, `tool.call`, `tool.result`, `error` +- [ ] Binary payloads: PCM/Opus frames with metadata in control channel +- [ ] Versioning: `v1` schema with backward compatibility rules + +--- + +## Self-Hosted DB Ops Checklist +- [ ] Postgres in Docker/k3s with persistent volumes +- [ ] Migrations: `alembic` or `atlas` +- [ ] Backups: `pgbackrest` nightly + on-demand +- [ ] Monitoring: postgres_exporter + alerts + +--- + +## RTC Adapter Contract (BYO-SFU First) +- [ ] Keep RTC pluggable; LiveKit optional, not core dependency +- [ ] Define adapter interface (TypeScript sketch) \ No newline at end of file diff --git a/engine/examples/mic_client.py b/engine/examples/mic_client.py new file mode 100644 index 0000000..509aeaa --- /dev/null +++ b/engine/examples/mic_client.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python3 +""" +Microphone client for testing duplex voice conversation. + +This client captures audio from the microphone, sends it to the server, +and plays back the AI's voice response through the speakers. +It also displays the LLM's text responses in the console. + +Usage: + python examples/mic_client.py --url ws://localhost:8000/ws + python examples/mic_client.py --url ws://localhost:8000/ws --chat "Hello!" 
+ python examples/mic_client.py --url ws://localhost:8000/ws --verbose + +Requirements: + pip install sounddevice soundfile websockets numpy +""" + +import argparse +import asyncio +import json +import sys +import time +import threading +import queue +from pathlib import Path + +try: + import numpy as np +except ImportError: + print("Please install numpy: pip install numpy") + sys.exit(1) + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice: pip install sounddevice") + sys.exit(1) + +try: + import websockets +except ImportError: + print("Please install websockets: pip install websockets") + sys.exit(1) + + +class MicrophoneClient: + """ + Full-duplex microphone client for voice conversation. + + Features: + - Real-time microphone capture + - Real-time speaker playback + - WebSocket communication + - Text chat support + """ + + def __init__( + self, + url: str, + sample_rate: int = 16000, + chunk_duration_ms: int = 20, + input_device: int = None, + output_device: int = None + ): + """ + Initialize microphone client. + + Args: + url: WebSocket server URL + sample_rate: Audio sample rate (Hz) + chunk_duration_ms: Audio chunk duration (ms) + input_device: Input device ID (None for default) + output_device: Output device ID (None for default) + """ + self.url = url + self.sample_rate = sample_rate + self.chunk_duration_ms = chunk_duration_ms + self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) + self.input_device = input_device + self.output_device = output_device + + # WebSocket connection + self.ws = None + self.running = False + + # Audio buffers + self.audio_input_queue = queue.Queue() + self.audio_output_buffer = b"" # Continuous buffer for smooth playback + self.audio_output_lock = threading.Lock() + + # Statistics + self.bytes_sent = 0 + self.bytes_received = 0 + + # State + self.is_recording = True + self.is_playing = True + + # TTFB tracking (Time to First Byte) + self.request_start_time = None + self.first_audio_received = False + + # Interrupt handling - discard audio until next trackStart + self._discard_audio = False + self._audio_sequence = 0 # Track audio sequence to detect stale chunks + + # Verbose mode for streaming LLM responses + self.verbose = False + + async def connect(self) -> None: + """Connect to WebSocket server.""" + print(f"Connecting to {self.url}...") + self.ws = await websockets.connect(self.url) + self.running = True + print("Connected!") + + # Send invite command + await self.send_command({ + "command": "invite", + "option": { + "codec": "pcm", + "sampleRate": self.sample_rate + } + }) + + async def send_command(self, cmd: dict) -> None: + """Send JSON command to server.""" + if self.ws: + await self.ws.send(json.dumps(cmd)) + print(f"→ Command: {cmd.get('command', 'unknown')}") + + async def send_chat(self, text: str) -> None: + """Send chat message (text input).""" + # Reset TTFB tracking for new request + self.request_start_time = time.time() + self.first_audio_received = False + + await self.send_command({ + "command": "chat", + "text": text + }) + print(f"→ Chat: {text}") + + async def send_interrupt(self) -> None: + """Send interrupt command.""" + await self.send_command({ + "command": "interrupt" + }) + + async def send_hangup(self, reason: str = "User quit") -> None: + """Send hangup command.""" + await self.send_command({ + "command": "hangup", + "reason": reason + }) + + def _audio_input_callback(self, indata, frames, time, status): + """Callback for audio input (microphone).""" + if status: + 
print(f"Input status: {status}") + + if self.is_recording and self.running: + # Convert to 16-bit PCM + audio_data = (indata[:, 0] * 32767).astype(np.int16).tobytes() + self.audio_input_queue.put(audio_data) + + def _add_audio_to_buffer(self, audio_data: bytes): + """Add audio data to playback buffer.""" + with self.audio_output_lock: + self.audio_output_buffer += audio_data + + def _playback_thread_func(self): + """Thread function for continuous audio playback.""" + import time + + # Chunk size: 50ms of audio + chunk_samples = int(self.sample_rate * 0.05) + chunk_bytes = chunk_samples * 2 + + print(f"Audio playback thread started (device: {self.output_device or 'default'})") + + try: + # Create output stream with callback + with sd.OutputStream( + samplerate=self.sample_rate, + channels=1, + dtype='int16', + blocksize=chunk_samples, + device=self.output_device, + latency='low' + ) as stream: + while self.running: + # Get audio from buffer + with self.audio_output_lock: + if len(self.audio_output_buffer) >= chunk_bytes: + audio_data = self.audio_output_buffer[:chunk_bytes] + self.audio_output_buffer = self.audio_output_buffer[chunk_bytes:] + else: + # Not enough audio - output silence + audio_data = b'\x00' * chunk_bytes + + # Convert to numpy array and write to stream + samples = np.frombuffer(audio_data, dtype=np.int16).reshape(-1, 1) + stream.write(samples) + + except Exception as e: + print(f"Playback thread error: {e}") + import traceback + traceback.print_exc() + + async def _playback_task(self): + """Start playback thread and monitor it.""" + # Run playback in a dedicated thread for reliable timing + playback_thread = threading.Thread(target=self._playback_thread_func, daemon=True) + playback_thread.start() + + # Wait for client to stop + while self.running and playback_thread.is_alive(): + await asyncio.sleep(0.1) + + print("Audio playback stopped") + + async def audio_sender(self) -> None: + """Send audio from microphone to server.""" + while self.running: + try: + # Get audio from queue with timeout + try: + audio_data = await asyncio.get_event_loop().run_in_executor( + None, lambda: self.audio_input_queue.get(timeout=0.1) + ) + except queue.Empty: + continue + + # Send to server + if self.ws and self.is_recording: + await self.ws.send(audio_data) + self.bytes_sent += len(audio_data) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Audio sender error: {e}") + break + + async def receiver(self) -> None: + """Receive messages from server.""" + try: + while self.running: + try: + message = await asyncio.wait_for(self.ws.recv(), timeout=0.1) + + if isinstance(message, bytes): + # Audio data received + self.bytes_received += len(message) + + # Check if we should discard this audio (after interrupt) + if self._discard_audio: + duration_ms = len(message) / (self.sample_rate * 2) * 1000 + print(f"← Audio: {duration_ms:.0f}ms (DISCARDED - waiting for new track)") + continue + + if self.is_playing: + self._add_audio_to_buffer(message) + + # Calculate and display TTFB for first audio packet + if not self.first_audio_received and self.request_start_time: + client_ttfb_ms = (time.time() - self.request_start_time) * 1000 + self.first_audio_received = True + print(f"← [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms") + + # Show progress (less verbose) + with self.audio_output_lock: + buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000 + duration_ms = len(message) / (self.sample_rate * 2) * 1000 + print(f"← Audio: {duration_ms:.0f}ms 
(buffer: {buffer_ms:.0f}ms)") + + else: + # JSON event + event = json.loads(message) + await self._handle_event(event) + + except asyncio.TimeoutError: + continue + except websockets.ConnectionClosed: + print("Connection closed") + self.running = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receiver error: {e}") + self.running = False + + async def _handle_event(self, event: dict) -> None: + """Handle incoming event.""" + event_type = event.get("event", "unknown") + + if event_type == "answer": + print("← Session ready!") + elif event_type == "speaking": + print("← User speech detected") + elif event_type == "silence": + print("← User silence detected") + elif event_type == "transcript": + # Display user speech transcription + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + # Clear the interim line and print final + print(" " * 80, end="\r") # Clear previous interim text + print(f"→ You: {text}") + else: + # Interim result - show with indicator (overwrite same line) + display_text = text[:60] + "..." if len(text) > 60 else text + print(f" [listening] {display_text}".ljust(80), end="\r") + elif event_type == "ttfb": + # Server-side TTFB event + latency_ms = event.get("latencyMs", 0) + print(f"← [TTFB] Server reported latency: {latency_ms}ms") + elif event_type == "llmResponse": + # LLM text response + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + # Print final LLM response + print(f"← AI: {text}") + elif self.verbose: + # Show streaming chunks only in verbose mode + display_text = text[:60] + "..." if len(text) > 60 else text + print(f" [streaming] {display_text}") + elif event_type == "trackStart": + print("← Bot started speaking") + # IMPORTANT: Accept audio again after trackStart + self._discard_audio = False + self._audio_sequence += 1 + # Reset TTFB tracking for voice responses (when no chat was sent) + if self.request_start_time is None: + self.request_start_time = time.time() + self.first_audio_received = False + # Clear any old audio in buffer + with self.audio_output_lock: + self.audio_output_buffer = b"" + elif event_type == "trackEnd": + print("← Bot finished speaking") + # Reset TTFB tracking after response completes + self.request_start_time = None + self.first_audio_received = False + elif event_type == "interrupt": + print("← Bot interrupted!") + # IMPORTANT: Discard all audio until next trackStart + self._discard_audio = True + # Clear audio buffer immediately + with self.audio_output_lock: + buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000 + self.audio_output_buffer = b"" + print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)") + elif event_type == "error": + print(f"← Error: {event.get('error')}") + elif event_type == "hangup": + print(f"← Hangup: {event.get('reason')}") + self.running = False + else: + print(f"← Event: {event_type}") + + async def interactive_mode(self) -> None: + """Run interactive mode for text chat.""" + print("\n" + "=" * 50) + print("Voice Conversation Client") + print("=" * 50) + print("Speak into your microphone to talk to the AI.") + print("Or type messages to send text.") + print("") + print("Commands:") + print(" /quit - End conversation") + print(" /mute - Mute microphone") + print(" /unmute - Unmute microphone") + print(" /interrupt - Interrupt AI speech") + print(" /stats - Show statistics") + print("=" * 50 + "\n") + + while self.running: + try: + user_input = await 
asyncio.get_event_loop().run_in_executor( + None, input, "" + ) + + if not user_input: + continue + + # Handle commands + if user_input.startswith("/"): + cmd = user_input.lower().strip() + + if cmd == "/quit": + await self.send_hangup("User quit") + break + elif cmd == "/mute": + self.is_recording = False + print("Microphone muted") + elif cmd == "/unmute": + self.is_recording = True + print("Microphone unmuted") + elif cmd == "/interrupt": + await self.send_interrupt() + elif cmd == "/stats": + print(f"Sent: {self.bytes_sent / 1024:.1f} KB") + print(f"Received: {self.bytes_received / 1024:.1f} KB") + else: + print(f"Unknown command: {cmd}") + else: + # Send as chat message + await self.send_chat(user_input) + + except EOFError: + break + except Exception as e: + print(f"Input error: {e}") + + async def run(self, chat_message: str = None, interactive: bool = True) -> None: + """ + Run the client. + + Args: + chat_message: Optional single chat message to send + interactive: Whether to run in interactive mode + """ + try: + await self.connect() + + # Wait for answer + await asyncio.sleep(0.5) + + # Start audio input stream + print("Starting audio streams...") + + input_stream = sd.InputStream( + samplerate=self.sample_rate, + channels=1, + dtype=np.float32, + blocksize=self.chunk_samples, + device=self.input_device, + callback=self._audio_input_callback + ) + + input_stream.start() + print("Audio streams started") + + # Start background tasks + sender_task = asyncio.create_task(self.audio_sender()) + receiver_task = asyncio.create_task(self.receiver()) + playback_task = asyncio.create_task(self._playback_task()) + + if chat_message: + # Send single message and wait + await self.send_chat(chat_message) + await asyncio.sleep(15) + elif interactive: + # Run interactive mode + await self.interactive_mode() + else: + # Just wait + while self.running: + await asyncio.sleep(0.1) + + # Cleanup + self.running = False + sender_task.cancel() + receiver_task.cancel() + playback_task.cancel() + + try: + await sender_task + except asyncio.CancelledError: + pass + + try: + await receiver_task + except asyncio.CancelledError: + pass + + try: + await playback_task + except asyncio.CancelledError: + pass + + input_stream.stop() + + except ConnectionRefusedError: + print(f"Error: Could not connect to {self.url}") + print("Make sure the server is running.") + except Exception as e: + print(f"Error: {e}") + finally: + await self.close() + + async def close(self) -> None: + """Close the connection.""" + self.running = False + if self.ws: + await self.ws.close() + + print(f"\nSession ended") + print(f" Total sent: {self.bytes_sent / 1024:.1f} KB") + print(f" Total received: {self.bytes_received / 1024:.1f} KB") + + +def list_devices(): + """List available audio devices.""" + print("\nAvailable audio devices:") + print("-" * 60) + devices = sd.query_devices() + for i, device in enumerate(devices): + direction = [] + if device['max_input_channels'] > 0: + direction.append("IN") + if device['max_output_channels'] > 0: + direction.append("OUT") + direction_str = "/".join(direction) if direction else "N/A" + + default = "" + if i == sd.default.device[0]: + default += " [DEFAULT INPUT]" + if i == sd.default.device[1]: + default += " [DEFAULT OUTPUT]" + + print(f" {i:2d}: {device['name'][:40]:40s} ({direction_str}){default}") + print("-" * 60) + + +async def main(): + parser = argparse.ArgumentParser( + description="Microphone client for duplex voice conversation" + ) + parser.add_argument( + "--url", + 
default="ws://localhost:8000/ws", + help="WebSocket server URL" + ) + parser.add_argument( + "--chat", + help="Send a single chat message instead of using microphone" + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Audio sample rate (default: 16000)" + ) + parser.add_argument( + "--input-device", + type=int, + help="Input device ID" + ) + parser.add_argument( + "--output-device", + type=int, + help="Output device ID" + ) + parser.add_argument( + "--list-devices", + action="store_true", + help="List available audio devices and exit" + ) + parser.add_argument( + "--no-interactive", + action="store_true", + help="Disable interactive mode" + ) + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Show streaming LLM response chunks" + ) + + args = parser.parse_args() + + if args.list_devices: + list_devices() + return + + client = MicrophoneClient( + url=args.url, + sample_rate=args.sample_rate, + input_device=args.input_device, + output_device=args.output_device + ) + client.verbose = args.verbose + + await client.run( + chat_message=args.chat, + interactive=not args.no_interactive + ) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nInterrupted by user") diff --git a/engine/examples/simple_client.py b/engine/examples/simple_client.py new file mode 100644 index 0000000..4280f93 --- /dev/null +++ b/engine/examples/simple_client.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +Simple WebSocket client for testing voice conversation. +Uses PyAudio for more reliable audio playback on Windows. + +Usage: + python examples/simple_client.py + python examples/simple_client.py --text "Hello" +""" + +import argparse +import asyncio +import json +import sys +import time +import wave +import io + +try: + import numpy as np +except ImportError: + print("pip install numpy") + sys.exit(1) + +try: + import websockets +except ImportError: + print("pip install websockets") + sys.exit(1) + +# Try PyAudio first (more reliable on Windows) +try: + import pyaudio + PYAUDIO_AVAILABLE = True +except ImportError: + PYAUDIO_AVAILABLE = False + print("PyAudio not available, trying sounddevice...") + +try: + import sounddevice as sd + SD_AVAILABLE = True +except ImportError: + SD_AVAILABLE = False + +if not PYAUDIO_AVAILABLE and not SD_AVAILABLE: + print("Please install pyaudio or sounddevice:") + print(" pip install pyaudio") + print(" or: pip install sounddevice") + sys.exit(1) + + +class SimpleVoiceClient: + """Simple voice client with reliable audio playback.""" + + def __init__(self, url: str, sample_rate: int = 16000): + self.url = url + self.sample_rate = sample_rate + self.ws = None + self.running = False + + # Audio buffer + self.audio_buffer = b"" + + # PyAudio setup + if PYAUDIO_AVAILABLE: + self.pa = pyaudio.PyAudio() + self.stream = None + + # Stats + self.bytes_received = 0 + + # TTFB tracking (Time to First Byte) + self.request_start_time = None + self.first_audio_received = False + + # Interrupt handling - discard audio until next trackStart + self._discard_audio = False + + async def connect(self): + """Connect to server.""" + print(f"Connecting to {self.url}...") + self.ws = await websockets.connect(self.url) + self.running = True + print("Connected!") + + # Send invite + await self.ws.send(json.dumps({ + "command": "invite", + "option": {"codec": "pcm", "sampleRate": self.sample_rate} + })) + print("-> invite") + + async def send_chat(self, text: str): + """Send chat message.""" + # Reset TTFB 
tracking for new request + self.request_start_time = time.time() + self.first_audio_received = False + + await self.ws.send(json.dumps({"command": "chat", "text": text})) + print(f"-> chat: {text}") + + def play_audio(self, audio_data: bytes): + """Play audio data immediately.""" + if len(audio_data) == 0: + return + + if PYAUDIO_AVAILABLE: + # Use PyAudio - more reliable on Windows + if self.stream is None: + self.stream = self.pa.open( + format=pyaudio.paInt16, + channels=1, + rate=self.sample_rate, + output=True, + frames_per_buffer=1024 + ) + self.stream.write(audio_data) + elif SD_AVAILABLE: + # Use sounddevice + samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0 + sd.play(samples, self.sample_rate, blocking=True) + + async def receive_loop(self): + """Receive and play audio.""" + print("\nWaiting for response...") + + while self.running: + try: + msg = await asyncio.wait_for(self.ws.recv(), timeout=0.1) + + if isinstance(msg, bytes): + # Audio data + self.bytes_received += len(msg) + duration_ms = len(msg) / (self.sample_rate * 2) * 1000 + + # Check if we should discard this audio (after interrupt) + if self._discard_audio: + print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms) [DISCARDED]") + continue + + # Calculate and display TTFB for first audio packet + if not self.first_audio_received and self.request_start_time: + client_ttfb_ms = (time.time() - self.request_start_time) * 1000 + self.first_audio_received = True + print(f"<- [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms") + + print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms)") + + # Play immediately in executor to not block + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self.play_audio, msg) + else: + # JSON event + event = json.loads(msg) + etype = event.get("event", "?") + + if etype == "transcript": + # User speech transcription + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + print(f"<- You said: {text}") + else: + print(f"<- [listening] {text}", end="\r") + elif etype == "ttfb": + # Server-side TTFB event + latency_ms = event.get("latencyMs", 0) + print(f"<- [TTFB] Server reported latency: {latency_ms}ms") + elif etype == "trackStart": + # New track starting - accept audio again + self._discard_audio = False + print(f"<- {etype}") + elif etype == "interrupt": + # Interrupt - discard audio until next trackStart + self._discard_audio = True + print(f"<- {etype} (discarding audio until new track)") + elif etype == "hangup": + print(f"<- {etype}") + self.running = False + break + else: + print(f"<- {etype}") + + except asyncio.TimeoutError: + continue + except websockets.ConnectionClosed: + print("Connection closed") + self.running = False + break + + async def run(self, text: str = None): + """Run the client.""" + try: + await self.connect() + await asyncio.sleep(0.5) + + # Start receiver + recv_task = asyncio.create_task(self.receive_loop()) + + if text: + await self.send_chat(text) + # Wait for response + await asyncio.sleep(30) + else: + # Interactive mode + print("\nType a message and press Enter (or 'quit' to exit):") + while self.running: + try: + user_input = await asyncio.get_event_loop().run_in_executor( + None, input, "> " + ) + if user_input.lower() == 'quit': + break + if user_input.strip(): + await self.send_chat(user_input) + except EOFError: + break + + self.running = False + recv_task.cancel() + try: + await recv_task + except asyncio.CancelledError: + pass + + finally: + await self.close() + + 
async def close(self): + """Close connections.""" + self.running = False + + if PYAUDIO_AVAILABLE: + if self.stream: + self.stream.stop_stream() + self.stream.close() + self.pa.terminate() + + if self.ws: + await self.ws.close() + + print(f"\nTotal audio received: {self.bytes_received / 1024:.1f} KB") + + +def list_audio_devices(): + """List available audio devices.""" + print("\n=== Audio Devices ===") + + if PYAUDIO_AVAILABLE: + pa = pyaudio.PyAudio() + print("\nPyAudio devices:") + for i in range(pa.get_device_count()): + info = pa.get_device_info_by_index(i) + if info['maxOutputChannels'] > 0: + default = " [DEFAULT]" if i == pa.get_default_output_device_info()['index'] else "" + print(f" {i}: {info['name']}{default}") + pa.terminate() + + if SD_AVAILABLE: + print("\nSounddevice devices:") + for i, d in enumerate(sd.query_devices()): + if d['max_output_channels'] > 0: + default = " [DEFAULT]" if i == sd.default.device[1] else "" + print(f" {i}: {d['name']}{default}") + + +async def main(): + parser = argparse.ArgumentParser(description="Simple voice client") + parser.add_argument("--url", default="ws://localhost:8000/ws") + parser.add_argument("--text", help="Send text and play response") + parser.add_argument("--list-devices", action="store_true") + parser.add_argument("--sample-rate", type=int, default=16000) + + args = parser.parse_args() + + if args.list_devices: + list_audio_devices() + return + + client = SimpleVoiceClient(args.url, args.sample_rate) + await client.run(args.text) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/engine/examples/test_websocket.py b/engine/examples/test_websocket.py new file mode 100644 index 0000000..20d388d --- /dev/null +++ b/engine/examples/test_websocket.py @@ -0,0 +1,166 @@ +"""WebSocket endpoint test client. + +Tests the /ws endpoint with sine wave or file audio streaming. 
+Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py +""" + +import asyncio +import aiohttp +import json +import struct +import math +import argparse +import os +from datetime import datetime + +# Configuration +SERVER_URL = "ws://localhost:8000/ws" +SAMPLE_RATE = 16000 +FREQUENCY = 440 # 440Hz Sine Wave +CHUNK_DURATION_MS = 20 +# 16kHz * 16-bit (2 bytes) * 20ms = 640 bytes per chunk +CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0)) + + +def generate_sine_wave(duration_ms=1000): + """Generates sine wave audio (16kHz mono PCM 16-bit).""" + num_samples = int(SAMPLE_RATE * (duration_ms / 1000.0)) + audio_data = bytearray() + + for x in range(num_samples): + # Generate sine wave sample + value = int(32767.0 * math.sin(2 * math.pi * FREQUENCY * x / SAMPLE_RATE)) + # Pack as little-endian 16-bit integer + audio_data.extend(struct.pack(' None: + """Connect to WebSocket server.""" + self.log_event("→", f"Connecting to {self.url}...") + self.ws = await websockets.connect(self.url) + self.running = True + self.log_event("←", "Connected!") + + # Send invite command + await self.send_command({ + "command": "invite", + "option": { + "codec": "pcm", + "sampleRate": self.sample_rate + } + }) + + async def send_command(self, cmd: dict) -> None: + """Send JSON command to server.""" + if self.ws: + await self.ws.send(json.dumps(cmd)) + self.log_event("→", f"Command: {cmd.get('command', 'unknown')}") + + async def send_hangup(self, reason: str = "Session complete") -> None: + """Send hangup command.""" + await self.send_command({ + "command": "hangup", + "reason": reason + }) + + def load_wav_file(self) -> tuple[np.ndarray, int]: + """ + Load and prepare WAV file for sending. + + Returns: + Tuple of (audio_data as int16 numpy array, original sample rate) + """ + if not self.input_file.exists(): + raise FileNotFoundError(f"Input file not found: {self.input_file}") + + # Load audio file + audio_data, file_sample_rate = sf.read(self.input_file) + self.log_event("→", f"Loaded: {self.input_file}") + self.log_event("→", f" Original sample rate: {file_sample_rate} Hz") + self.log_event("→", f" Duration: {len(audio_data) / file_sample_rate:.2f}s") + + # Convert stereo to mono if needed + if len(audio_data.shape) > 1: + audio_data = audio_data.mean(axis=1) + self.log_event("→", " Converted stereo to mono") + + # Resample if needed + if file_sample_rate != self.sample_rate: + # Simple resampling using numpy + duration = len(audio_data) / file_sample_rate + num_samples = int(duration * self.sample_rate) + indices = np.linspace(0, len(audio_data) - 1, num_samples) + audio_data = np.interp(indices, np.arange(len(audio_data)), audio_data) + self.log_event("→", f" Resampled to {self.sample_rate} Hz") + + # Convert to int16 + if audio_data.dtype != np.int16: + # Normalize to [-1, 1] if needed + max_val = np.max(np.abs(audio_data)) + if max_val > 1.0: + audio_data = audio_data / max_val + audio_data = (audio_data * 32767).astype(np.int16) + + self.log_event("→", f" Prepared: {len(audio_data)} samples ({len(audio_data)/self.sample_rate:.2f}s)") + + return audio_data, file_sample_rate + + async def audio_sender(self, audio_data: np.ndarray) -> None: + """Send audio data to server in chunks.""" + total_samples = len(audio_data) + chunk_size = self.chunk_samples + sent_samples = 0 + + self.send_start_time = time.time() + self.log_event("→", f"Starting audio transmission ({total_samples} samples)...") + + while sent_samples < total_samples and self.running: + # Get next chunk + end_sample = 
min(sent_samples + chunk_size, total_samples) + chunk = audio_data[sent_samples:end_sample] + chunk_bytes = chunk.tobytes() + + # Send to server + if self.ws: + await self.ws.send(chunk_bytes) + self.bytes_sent += len(chunk_bytes) + + sent_samples = end_sample + + # Progress logging (every 500ms worth of audio) + if self.verbose and sent_samples % (self.sample_rate // 2) == 0: + progress = (sent_samples / total_samples) * 100 + print(f" Sending: {progress:.0f}%", end="\r") + + # Delay to simulate real-time streaming + # Server expects audio at real-time pace for VAD/ASR to work properly + await asyncio.sleep(self.chunk_duration_ms / 1000) + + self.send_completed = True + elapsed = time.time() - self.send_start_time + self.log_event("→", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)") + + async def receiver(self) -> None: + """Receive messages from server.""" + try: + while self.running: + try: + message = await asyncio.wait_for(self.ws.recv(), timeout=0.1) + + if isinstance(message, bytes): + # Audio data received + self.bytes_received += len(message) + self.received_audio.extend(message) + + # Calculate TTFB on first audio of each response + if self.waiting_for_first_audio and self.response_start_time is not None: + ttfb_ms = (time.time() - self.response_start_time) * 1000 + self.ttfb_ms = ttfb_ms + self.ttfb_list.append(ttfb_ms) + self.waiting_for_first_audio = False + self.log_event("←", f"[TTFB] First audio latency: {ttfb_ms:.0f}ms") + + # Log progress + duration_ms = len(message) / (self.sample_rate * 2) * 1000 + total_ms = len(self.received_audio) / (self.sample_rate * 2) * 1000 + if self.verbose: + print(f"← Audio: +{duration_ms:.0f}ms (total: {total_ms:.0f}ms)", end="\r") + + else: + # JSON event + event = json.loads(message) + await self._handle_event(event) + + except asyncio.TimeoutError: + continue + except websockets.ConnectionClosed: + self.log_event("←", "Connection closed") + self.running = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + self.log_event("!", f"Receiver error: {e}") + self.running = False + + async def _handle_event(self, event: dict) -> None: + """Handle incoming event.""" + event_type = event.get("event", "unknown") + + if event_type == "answer": + self.log_event("←", "Session ready!") + elif event_type == "speaking": + self.log_event("←", "Speech detected") + elif event_type == "silence": + self.log_event("←", "Silence detected") + elif event_type == "transcript": + # ASR transcript (interim = asrDelta-style, final = asrFinal-style) + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + # Clear interim line and print final + print(" " * 80, end="\r") + self.log_event("←", f"→ You: {text}") + else: + # Interim result - show with indicator (overwrite same line, as in mic_client) + display_text = text[:60] + "..." if len(text) > 60 else text + print(f" [listening] {display_text}".ljust(80), end="\r") + elif event_type == "ttfb": + latency_ms = event.get("latencyMs", 0) + self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms") + elif event_type == "llmResponse": + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + self.log_event("←", f"LLM Response (final): {text[:100]}{'...' 
if len(text) > 100 else ''}") + elif self.verbose: + # Show streaming chunks only in verbose mode + self.log_event("←", f"LLM: {text}") + elif event_type == "trackStart": + self.track_started = True + self.response_start_time = time.time() + self.waiting_for_first_audio = True + self.log_event("←", "Bot started speaking") + elif event_type == "trackEnd": + self.track_ended = True + self.log_event("←", "Bot finished speaking") + elif event_type == "interrupt": + self.log_event("←", "Bot interrupted!") + elif event_type == "error": + self.log_event("!", f"Error: {event.get('error')}") + elif event_type == "hangup": + self.log_event("←", f"Hangup: {event.get('reason')}") + self.running = False + else: + self.log_event("←", f"Event: {event_type}") + + def save_output_wav(self) -> None: + """Save received audio to output WAV file.""" + if not self.received_audio: + self.log_event("!", "No audio received to save") + return + + # Convert bytes to numpy array + audio_data = np.frombuffer(bytes(self.received_audio), dtype=np.int16) + + # Ensure output directory exists + self.output_file.parent.mkdir(parents=True, exist_ok=True) + + # Save using wave module for compatibility + with wave.open(str(self.output_file), 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # 16-bit + wav_file.setframerate(self.sample_rate) + wav_file.writeframes(audio_data.tobytes()) + + duration = len(audio_data) / self.sample_rate + self.log_event("→", f"Saved output: {self.output_file}") + self.log_event("→", f" Duration: {duration:.2f}s ({len(audio_data)} samples)") + self.log_event("→", f" Size: {len(self.received_audio)/1024:.1f} KB") + + async def run(self) -> None: + """Run the WAV file test.""" + try: + # Load input WAV file + audio_data, _ = self.load_wav_file() + + # Connect to server + await self.connect() + + # Wait for answer + await asyncio.sleep(0.5) + + # Start receiver task + receiver_task = asyncio.create_task(self.receiver()) + + # Send audio + await self.audio_sender(audio_data) + + # Wait for response + self.log_event("→", f"Waiting {self.wait_time}s for response...") + + wait_start = time.time() + while self.running and (time.time() - wait_start) < self.wait_time: + # Check if track has ended (response complete) + if self.track_ended and self.send_completed: + # Give a little extra time for any remaining audio + await asyncio.sleep(1.0) + break + await asyncio.sleep(0.1) + + # Cleanup + self.running = False + receiver_task.cancel() + + try: + await receiver_task + except asyncio.CancelledError: + pass + + # Save output + self.save_output_wav() + + # Print summary + self._print_summary() + + except FileNotFoundError as e: + print(f"Error: {e}") + sys.exit(1) + except ConnectionRefusedError: + print(f"Error: Could not connect to {self.url}") + print("Make sure the server is running.") + sys.exit(1) + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + finally: + await self.close() + + def _print_summary(self): + """Print session summary.""" + print("\n" + "=" * 50) + print("Session Summary") + print("=" * 50) + print(f" Input file: {self.input_file}") + print(f" Output file: {self.output_file}") + print(f" Bytes sent: {self.bytes_sent / 1024:.1f} KB") + print(f" Bytes received: {self.bytes_received / 1024:.1f} KB") + if self.ttfb_list: + if len(self.ttfb_list) == 1: + print(f" TTFB: {self.ttfb_list[0]:.0f} ms") + else: + print(f" TTFB (per response): {', '.join(f'{t:.0f}ms' for t in self.ttfb_list)}") + if self.received_audio: + 
duration = len(self.received_audio) / (self.sample_rate * 2) + print(f" Response duration: {duration:.2f}s") + print("=" * 50) + + async def close(self) -> None: + """Close the connection.""" + self.running = False + if self.ws: + try: + await self.ws.close() + except: + pass + + +async def main(): + parser = argparse.ArgumentParser( + description="WAV file client for testing duplex voice conversation" + ) + parser.add_argument( + "--input", "-i", + required=True, + help="Input WAV file path" + ) + parser.add_argument( + "--output", "-o", + required=True, + help="Output WAV file path for response" + ) + parser.add_argument( + "--url", + default="ws://localhost:8000/ws", + help="WebSocket server URL (default: ws://localhost:8000/ws)" + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Target sample rate for audio (default: 16000)" + ) + parser.add_argument( + "--chunk-duration", + type=int, + default=20, + help="Chunk duration in ms for sending (default: 20)" + ) + parser.add_argument( + "--wait-time", "-w", + type=float, + default=15.0, + help="Time to wait for response after sending (default: 15.0)" + ) + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Enable verbose output" + ) + + args = parser.parse_args() + + client = WavFileClient( + url=args.url, + input_file=args.input, + output_file=args.output, + sample_rate=args.sample_rate, + chunk_duration_ms=args.chunk_duration, + wait_time=args.wait_time, + verbose=args.verbose + ) + + await client.run() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nInterrupted by user") diff --git a/engine/examples/web_client.html b/engine/examples/web_client.html new file mode 100644 index 0000000..bee3d28 --- /dev/null +++ b/engine/examples/web_client.html @@ -0,0 +1,742 @@ + + + + + + Duplex Voice Web Client + + + +
[web_client.html: the 742-line HTML/JS markup was lost in this view and only its visible text survived. The page is a "Duplex Voice Client" -- a browser client for the WebSocket duplex pipeline with device selection and event logging -- laid out as Connection (status readout, initially "Disconnected" / "Waiting for connection"), Devices, Chat, Chat History, and Event Log panels.]
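Since the page's script did not survive, here is a minimal sketch of the browser-side protocol it drives, assuming the /ws endpoint and the invite command behave as in the Python example clients above; audio capture and playback plumbing are elided and the helper names (`sendMicChunk`, `enqueuePcm`) are hypothetical:

```typescript
// Connect and negotiate the session exactly like mic_client.py / simple_client.py.
const ws = new WebSocket("ws://localhost:8000/ws");
ws.binaryType = "arraybuffer";

ws.onopen = () => {
  ws.send(JSON.stringify({ command: "invite", option: { codec: "pcm", sampleRate: 16000 } }));
};

ws.onmessage = ({ data }) => {
  if (data instanceof ArrayBuffer) {
    // Binary frame: 16-bit mono PCM from TTS -- queue it for playback.
    enqueuePcm(new Int16Array(data));
  } else {
    // Text frame: JSON event (answer, transcript, trackStart, interrupt, ...).
    const event = JSON.parse(data);
    console.log("event:", event.event);
  }
};

// Convert one chunk of captured microphone audio (Float32 at 16 kHz) to 16-bit PCM and send it.
function sendMicChunk(f32: Float32Array): void {
  const i16 = new Int16Array(f32.length);
  for (let i = 0; i < f32.length; i++) {
    i16[i] = Math.max(-1, Math.min(1, f32[i])) * 32767;
  }
  if (ws.readyState === WebSocket.OPEN) ws.send(i16.buffer);
}

// Playback is typically buffered through an AudioWorklet or AudioBufferSourceNode.
function enqueuePcm(samples: Int16Array): void {
  /* buffering and playback omitted in this sketch */
}
```

The interrupt handling mirrors the Python clients: on an `interrupt` event the playback buffer should be flushed and incoming PCM discarded until the next `trackStart`.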
+ + + + + + + + diff --git a/engine/models/__init__.py b/engine/models/__init__.py new file mode 100644 index 0000000..924d5fd --- /dev/null +++ b/engine/models/__init__.py @@ -0,0 +1 @@ +"""Data Models Package""" diff --git a/engine/models/commands.py b/engine/models/commands.py new file mode 100644 index 0000000..5bcf47e --- /dev/null +++ b/engine/models/commands.py @@ -0,0 +1,143 @@ +"""Protocol command models matching the original active-call API.""" + +from typing import Optional, Dict, Any +from pydantic import BaseModel, Field + + +class InviteCommand(BaseModel): + """Invite command to initiate a call.""" + + command: str = Field(default="invite", description="Command type") + option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options") + + +class AcceptCommand(BaseModel): + """Accept command to accept an incoming call.""" + + command: str = Field(default="accept", description="Command type") + option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options") + + +class RejectCommand(BaseModel): + """Reject command to reject an incoming call.""" + + command: str = Field(default="reject", description="Command type") + reason: str = Field(default="", description="Reason for rejection") + code: Optional[int] = Field(default=None, description="SIP response code") + + +class RingingCommand(BaseModel): + """Ringing command to send ringing response.""" + + command: str = Field(default="ringing", description="Command type") + recorder: Optional[Dict[str, Any]] = Field(default=None, description="Call recording configuration") + early_media: bool = Field(default=False, description="Enable early media") + ringtone: Optional[str] = Field(default=None, description="Custom ringtone URL") + + +class TTSCommand(BaseModel): + """TTS command to convert text to speech.""" + + command: str = Field(default="tts", description="Command type") + text: str = Field(..., description="Text to synthesize") + speaker: Optional[str] = Field(default=None, description="Speaker voice name") + play_id: Optional[str] = Field(default=None, description="Unique identifier for this TTS session") + auto_hangup: bool = Field(default=False, description="Auto hangup after TTS completion") + streaming: bool = Field(default=False, description="Streaming text input") + end_of_stream: bool = Field(default=False, description="End of streaming input") + wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)") + option: Optional[Dict[str, Any]] = Field(default=None, description="TTS provider specific options") + + +class PlayCommand(BaseModel): + """Play command to play audio from URL.""" + + command: str = Field(default="play", description="Command type") + url: str = Field(..., description="URL of audio file to play") + auto_hangup: bool = Field(default=False, description="Auto hangup after playback") + wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)") + + +class InterruptCommand(BaseModel): + """Interrupt command to interrupt current playback.""" + + command: str = Field(default="interrupt", description="Command type") + graceful: bool = Field(default=False, description="Wait for current TTS to complete") + + +class PauseCommand(BaseModel): + """Pause command to pause current playback.""" + + command: str = Field(default="pause", description="Command type") + + +class ResumeCommand(BaseModel): + """Resume command to resume paused playback.""" + + command: 
str = Field(default="resume", description="Command type") + + +class HangupCommand(BaseModel): + """Hangup command to end the call.""" + + command: str = Field(default="hangup", description="Command type") + reason: Optional[str] = Field(default=None, description="Reason for hangup") + initiator: Optional[str] = Field(default=None, description="Who initiated the hangup") + + +class HistoryCommand(BaseModel): + """History command to add conversation history.""" + + command: str = Field(default="history", description="Command type") + speaker: str = Field(..., description="Speaker identifier") + text: str = Field(..., description="Conversation text") + + +class ChatCommand(BaseModel): + """Chat command for text-based conversation.""" + + command: str = Field(default="chat", description="Command type") + text: str = Field(..., description="Chat text message") + + +# Command type mapping +COMMAND_TYPES = { + "invite": InviteCommand, + "accept": AcceptCommand, + "reject": RejectCommand, + "ringing": RingingCommand, + "tts": TTSCommand, + "play": PlayCommand, + "interrupt": InterruptCommand, + "pause": PauseCommand, + "resume": ResumeCommand, + "hangup": HangupCommand, + "history": HistoryCommand, + "chat": ChatCommand, +} + + +def parse_command(data: Dict[str, Any]) -> BaseModel: + """ + Parse a command from JSON data. + + Args: + data: JSON data as dictionary + + Returns: + Parsed command model + + Raises: + ValueError: If command type is unknown + """ + command_type = data.get("command") + + if not command_type: + raise ValueError("Missing 'command' field") + + command_class = COMMAND_TYPES.get(command_type) + + if not command_class: + raise ValueError(f"Unknown command type: {command_type}") + + return command_class(**data) diff --git a/engine/models/config.py b/engine/models/config.py new file mode 100644 index 0000000..009411e --- /dev/null +++ b/engine/models/config.py @@ -0,0 +1,126 @@ +"""Configuration models for call options.""" + +from typing import Optional, Dict, Any, List +from pydantic import BaseModel, Field + + +class VADOption(BaseModel): + """Voice Activity Detection configuration.""" + + type: str = Field(default="silero", description="VAD algorithm type (silero, webrtc)") + samplerate: int = Field(default=16000, description="Audio sample rate for VAD") + speech_padding: int = Field(default=250, description="Speech padding in milliseconds") + silence_padding: int = Field(default=100, description="Silence padding in milliseconds") + ratio: float = Field(default=0.5, description="Voice detection ratio threshold") + voice_threshold: float = Field(default=0.5, description="Voice energy threshold") + max_buffer_duration_secs: int = Field(default=50, description="Maximum buffer duration in seconds") + silence_timeout: Optional[int] = Field(default=None, description="Silence timeout in milliseconds") + endpoint: Optional[str] = Field(default=None, description="Custom VAD service endpoint") + secret_key: Optional[str] = Field(default=None, description="VAD service secret key") + secret_id: Optional[str] = Field(default=None, description="VAD service secret ID") + + +class ASROption(BaseModel): + """Automatic Speech Recognition configuration.""" + + provider: str = Field(..., description="ASR provider (tencent, aliyun, openai, etc.)") + language: Optional[str] = Field(default=None, description="Language code (zh-CN, en-US)") + app_id: Optional[str] = Field(default=None, description="Application ID") + secret_id: Optional[str] = Field(default=None, description="Secret ID for 
authentication") + secret_key: Optional[str] = Field(default=None, description="Secret key for authentication") + model_type: Optional[str] = Field(default=None, description="ASR model type (16k_zh, 8k_en)") + buffer_size: Optional[int] = Field(default=None, description="Audio buffer size in bytes") + samplerate: Optional[int] = Field(default=None, description="Audio sample rate") + endpoint: Optional[str] = Field(default=None, description="Custom ASR service endpoint") + extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters") + start_when_answer: bool = Field(default=False, description="Start ASR when call is answered") + + +class TTSOption(BaseModel): + """Text-to-Speech configuration.""" + + samplerate: Optional[int] = Field(default=None, description="TTS output sample rate") + provider: str = Field(default="msedge", description="TTS provider (tencent, aliyun, deepgram, msedge)") + speed: float = Field(default=1.0, description="Speech speed multiplier") + app_id: Optional[str] = Field(default=None, description="Application ID") + secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication") + secret_key: Optional[str] = Field(default=None, description="Secret key for authentication") + volume: Optional[int] = Field(default=None, description="Speech volume level (1-10)") + speaker: Optional[str] = Field(default=None, description="Voice speaker name") + codec: Optional[str] = Field(default=None, description="Audio codec") + subtitle: bool = Field(default=False, description="Enable subtitle generation") + emotion: Optional[str] = Field(default=None, description="Speech emotion") + endpoint: Optional[str] = Field(default=None, description="Custom TTS service endpoint") + extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters") + max_concurrent_tasks: Optional[int] = Field(default=None, description="Max concurrent tasks") + + +class RecorderOption(BaseModel): + """Call recording configuration.""" + + recorder_file: str = Field(..., description="Path to recording file") + samplerate: int = Field(default=16000, description="Recording sample rate") + ptime: int = Field(default=200, description="Packet time in milliseconds") + + +class MediaPassOption(BaseModel): + """Media pass-through configuration for external audio processing.""" + + url: str = Field(..., description="WebSocket URL for media streaming") + input_sample_rate: int = Field(default=16000, description="Sample rate of audio received from WebSocket") + output_sample_rate: int = Field(default=16000, description="Sample rate of audio sent to WebSocket") + packet_size: int = Field(default=2560, description="Packet size in bytes") + ptime: Optional[int] = Field(default=None, description="Buffered playback period in milliseconds") + + +class SipOption(BaseModel): + """SIP protocol configuration.""" + + username: Optional[str] = Field(default=None, description="SIP username") + password: Optional[str] = Field(default=None, description="SIP password") + realm: Optional[str] = Field(default=None, description="SIP realm/domain") + headers: Optional[Dict[str, str]] = Field(default=None, description="Additional SIP headers") + + +class HandlerRule(BaseModel): + """Handler routing rule.""" + + caller: Optional[str] = Field(default=None, description="Caller pattern (regex)") + callee: Optional[str] = Field(default=None, description="Callee pattern (regex)") + playbook: Optional[str] = Field(default=None, description="Playbook file path") + webhook: 
Optional[str] = Field(default=None, description="Webhook URL") + + +class CallOption(BaseModel): + """Comprehensive call configuration options.""" + + # Basic options + denoise: bool = Field(default=False, description="Enable noise reduction") + offer: Optional[str] = Field(default=None, description="SDP offer string") + callee: Optional[str] = Field(default=None, description="Callee SIP URI or phone number") + caller: Optional[str] = Field(default=None, description="Caller SIP URI or phone number") + + # Audio codec + codec: str = Field(default="pcm", description="Audio codec (pcm, pcma, pcmu, g722)") + + # Component configurations + recorder: Optional[RecorderOption] = Field(default=None, description="Call recording config") + asr: Optional[ASROption] = Field(default=None, description="ASR configuration") + vad: Optional[VADOption] = Field(default=None, description="VAD configuration") + tts: Optional[TTSOption] = Field(default=None, description="TTS configuration") + media_pass: Optional[MediaPassOption] = Field(default=None, description="Media pass-through config") + sip: Optional[SipOption] = Field(default=None, description="SIP configuration") + + # Timeouts and networking + handshake_timeout: Optional[int] = Field(default=None, description="Handshake timeout in seconds") + enable_ipv6: bool = Field(default=False, description="Enable IPv6 support") + inactivity_timeout: Optional[int] = Field(default=None, description="Inactivity timeout in seconds") + + # EOU configuration + eou: Optional[Dict[str, Any]] = Field(default=None, description="End of utterance detection config") + + # Extra parameters + extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional custom parameters") + + class Config: + populate_by_name = True diff --git a/engine/models/events.py b/engine/models/events.py new file mode 100644 index 0000000..031b8be --- /dev/null +++ b/engine/models/events.py @@ -0,0 +1,231 @@ +"""Protocol event models matching the original active-call API.""" + +from typing import Optional, Dict, Any +from pydantic import BaseModel, Field +from datetime import datetime + + +def current_timestamp_ms() -> int: + """Get current timestamp in milliseconds.""" + return int(datetime.now().timestamp() * 1000) + + +# Base Event Model +class BaseEvent(BaseModel): + """Base event model.""" + + event: str = Field(..., description="Event type") + track_id: str = Field(..., description="Unique track identifier") + timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds") + + +# Lifecycle Events +class IncomingEvent(BaseEvent): + """Incoming call event (SIP only).""" + + event: str = Field(default="incoming", description="Event type") + caller: Optional[str] = Field(default=None, description="Caller's SIP URI") + callee: Optional[str] = Field(default=None, description="Callee's SIP URI") + sdp: Optional[str] = Field(default=None, description="SDP offer from caller") + + +class AnswerEvent(BaseEvent): + """Call answered event.""" + + event: str = Field(default="answer", description="Event type") + sdp: Optional[str] = Field(default=None, description="SDP answer from server") + + +class RejectEvent(BaseEvent): + """Call rejected event.""" + + event: str = Field(default="reject", description="Event type") + reason: Optional[str] = Field(default=None, description="Rejection reason") + code: Optional[int] = Field(default=None, description="SIP response code") + + +class RingingEvent(BaseEvent): + """Call ringing event.""" + + event: str = 
Field(default="ringing", description="Event type") + early_media: bool = Field(default=False, description="Early media available") + + +class HangupEvent(BaseModel): + """Call hangup event.""" + + event: str = Field(default="hangup", description="Event type") + timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp") + reason: Optional[str] = Field(default=None, description="Hangup reason") + initiator: Optional[str] = Field(default=None, description="Who initiated hangup") + start_time: Optional[str] = Field(default=None, description="Call start time (ISO 8601)") + hangup_time: Optional[str] = Field(default=None, description="Hangup time (ISO 8601)") + answer_time: Optional[str] = Field(default=None, description="Answer time (ISO 8601)") + ringing_time: Optional[str] = Field(default=None, description="Ringing time (ISO 8601)") + from_: Optional[Dict[str, Any]] = Field(default=None, alias="from", description="Caller info") + to: Optional[Dict[str, Any]] = Field(default=None, description="Callee info") + extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata") + + class Config: + populate_by_name = True + + +# VAD Events +class SpeakingEvent(BaseEvent): + """Speech detected event.""" + + event: str = Field(default="speaking", description="Event type") + start_time: int = Field(default_factory=current_timestamp_ms, description="Speech start time") + + +class SilenceEvent(BaseEvent): + """Silence detected event.""" + + event: str = Field(default="silence", description="Event type") + start_time: int = Field(default_factory=current_timestamp_ms, description="Silence start time") + duration: int = Field(default=0, description="Silence duration in milliseconds") + + +# AI/ASR Events +class AsrFinalEvent(BaseEvent): + """ASR final transcription event.""" + + event: str = Field(default="asrFinal", description="Event type") + index: int = Field(..., description="ASR result sequence number") + start_time: Optional[int] = Field(default=None, description="Speech start time") + end_time: Optional[int] = Field(default=None, description="Speech end time") + text: str = Field(..., description="Transcribed text") + + +class AsrDeltaEvent(BaseEvent): + """ASR partial transcription event (streaming).""" + + event: str = Field(default="asrDelta", description="Event type") + index: int = Field(..., description="ASR result sequence number") + start_time: Optional[int] = Field(default=None, description="Speech start time") + end_time: Optional[int] = Field(default=None, description="Speech end time") + text: str = Field(..., description="Partial transcribed text") + + +class EouEvent(BaseEvent): + """End of utterance detection event.""" + + event: str = Field(default="eou", description="Event type") + completed: bool = Field(default=True, description="Whether utterance was completed") + + +# Audio Track Events +class TrackStartEvent(BaseEvent): + """Audio track start event.""" + + event: str = Field(default="trackStart", description="Event type") + play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command") + + +class TrackEndEvent(BaseEvent): + """Audio track end event.""" + + event: str = Field(default="trackEnd", description="Event type") + duration: int = Field(..., description="Track duration in milliseconds") + ssrc: int = Field(..., description="RTP SSRC identifier") + play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command") + + +class InterruptionEvent(BaseEvent): + 
"""Playback interruption event.""" + + event: str = Field(default="interruption", description="Event type") + play_id: Optional[str] = Field(default=None, description="Play ID that was interrupted") + subtitle: Optional[str] = Field(default=None, description="TTS text being played") + position: Optional[int] = Field(default=None, description="Word index position") + total_duration: Optional[int] = Field(default=None, description="Total TTS duration") + current: Optional[int] = Field(default=None, description="Elapsed time when interrupted") + + +# System Events +class ErrorEvent(BaseEvent): + """Error event.""" + + event: str = Field(default="error", description="Event type") + sender: str = Field(..., description="Component that generated the error") + error: str = Field(..., description="Error message") + code: Optional[int] = Field(default=None, description="Error code") + + +class MetricsEvent(BaseModel): + """Performance metrics event.""" + + event: str = Field(default="metrics", description="Event type") + timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp") + key: str = Field(..., description="Metric key") + duration: int = Field(..., description="Duration in milliseconds") + data: Optional[Dict[str, Any]] = Field(default=None, description="Additional metric data") + + +class AddHistoryEvent(BaseModel): + """Conversation history entry added event.""" + + event: str = Field(default="addHistory", description="Event type") + timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp") + sender: Optional[str] = Field(default=None, description="Component that added history") + speaker: str = Field(..., description="Speaker identifier") + text: str = Field(..., description="Conversation text") + + +class DTMFEvent(BaseEvent): + """DTMF tone detected event.""" + + event: str = Field(default="dtmf", description="Event type") + digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)") + + +class HeartBeatEvent(BaseModel): + """Server-to-client heartbeat to keep connection alive.""" + + event: str = Field(default="heartBeat", description="Event type") + timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds") + + +# Event type mapping +EVENT_TYPES = { + "incoming": IncomingEvent, + "answer": AnswerEvent, + "reject": RejectEvent, + "ringing": RingingEvent, + "hangup": HangupEvent, + "speaking": SpeakingEvent, + "silence": SilenceEvent, + "asrFinal": AsrFinalEvent, + "asrDelta": AsrDeltaEvent, + "eou": EouEvent, + "trackStart": TrackStartEvent, + "trackEnd": TrackEndEvent, + "interruption": InterruptionEvent, + "error": ErrorEvent, + "metrics": MetricsEvent, + "addHistory": AddHistoryEvent, + "dtmf": DTMFEvent, + "heartBeat": HeartBeatEvent, +} + + +def create_event(event_type: str, **kwargs) -> BaseModel: + """ + Create an event model. 
+ + Args: + event_type: Type of event to create + **kwargs: Event fields + + Returns: + Event model instance + + Raises: + ValueError: If event type is unknown + """ + event_class = EVENT_TYPES.get(event_type) + + if not event_class: + raise ValueError(f"Unknown event type: {event_type}") + + return event_class(event=event_type, **kwargs) diff --git a/engine/processors/__init__.py b/engine/processors/__init__.py new file mode 100644 index 0000000..1952777 --- /dev/null +++ b/engine/processors/__init__.py @@ -0,0 +1,6 @@ +"""Audio Processors Package""" + +from processors.eou import EouDetector +from processors.vad import SileroVAD, VADProcessor + +__all__ = ["EouDetector", "SileroVAD", "VADProcessor"] diff --git a/engine/processors/eou.py b/engine/processors/eou.py new file mode 100644 index 0000000..baf6807 --- /dev/null +++ b/engine/processors/eou.py @@ -0,0 +1,80 @@ +"""End-of-Utterance Detection.""" + +import time +from typing import Optional + + +class EouDetector: + """ + End-of-utterance detector. Fires EOU only after continuous silence for + silence_threshold_ms. Short pauses between sentences do not trigger EOU + because speech resets the silence timer (one EOU per turn). + """ + + def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250): + """ + Initialize EOU detector. + + Args: + silence_threshold_ms: How long silence must last to trigger EOU (default 1000ms) + min_speech_duration_ms: Minimum speech duration to consider valid (default 250ms) + """ + self.threshold = silence_threshold_ms / 1000.0 + self.min_speech = min_speech_duration_ms / 1000.0 + self._silence_threshold_ms = silence_threshold_ms + self._min_speech_duration_ms = min_speech_duration_ms + + # State + self.is_speaking = False + self.speech_start_time = 0.0 + self.silence_start_time: Optional[float] = None + self.triggered = False + + def process(self, vad_status: str) -> bool: + """ + Process VAD status and detect end of utterance. + + Input: "Speech" or "Silence" (from VAD). + Output: True if EOU detected, False otherwise. + + Short breaks between phrases reset the silence clock when speech + resumes, so only one EOU is emitted after the user truly stops. 
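+
+        Illustrative call sequence (hypothetical timings, default thresholds):
+
+            process("Speech")    # -> False, speech starts
+            process("Silence")   # -> False, silence clock starts (speech lasted >= 250 ms)
+            process("Speech")    # -> False, short pause: silence clock resets
+            process("Silence")   # -> False, silence clock restarts
+            # ...after >= 1000 ms of continuous silence...
+            process("Silence")   # -> True, EOU fired once for this turn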
+ """ + now = time.time() + + if vad_status == "Speech": + if not self.is_speaking: + self.is_speaking = True + self.speech_start_time = now + self.triggered = False + # Any speech resets silence timer — short pause + more speech = one utterance + self.silence_start_time = None + return False + + if vad_status == "Silence": + if not self.is_speaking: + return False + if self.silence_start_time is None: + self.silence_start_time = now + + speech_duration = self.silence_start_time - self.speech_start_time + if speech_duration < self.min_speech: + self.is_speaking = False + self.silence_start_time = None + return False + + silence_duration = now - self.silence_start_time + if silence_duration >= self.threshold and not self.triggered: + self.triggered = True + self.is_speaking = False + self.silence_start_time = None + return True + + return False + + def reset(self) -> None: + """Reset EOU detector state.""" + self.is_speaking = False + self.speech_start_time = 0.0 + self.silence_start_time = None + self.triggered = False diff --git a/engine/processors/tracks.py b/engine/processors/tracks.py new file mode 100644 index 0000000..71f3cbd --- /dev/null +++ b/engine/processors/tracks.py @@ -0,0 +1,168 @@ +"""Audio track processing for WebRTC.""" + +import asyncio +import fractions +from typing import Optional +from loguru import logger + +# Try to import aiortc (optional for WebRTC functionality) +try: + from aiortc import AudioStreamTrack + AIORTC_AVAILABLE = True +except ImportError: + AIORTC_AVAILABLE = False + AudioStreamTrack = object # Dummy class for type hints + +# Try to import PyAV (optional for audio resampling) +try: + from av import AudioFrame, AudioResampler + AV_AVAILABLE = True +except ImportError: + AV_AVAILABLE = False + # Create dummy classes for type hints + class AudioFrame: + pass + class AudioResampler: + pass + +import numpy as np + + +class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object): + """ + Audio track that resamples input to 16kHz mono PCM. + + Wraps an existing MediaStreamTrack and converts its output + to 16kHz mono 16-bit PCM format for the pipeline. + """ + + def __init__(self, track, target_sample_rate: int = 16000): + """ + Initialize resampled track. + + Args: + track: Source MediaStreamTrack + target_sample_rate: Target sample rate (default: 16000) + """ + if not AIORTC_AVAILABLE: + raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used") + + super().__init__() + self.track = track + self.target_sample_rate = target_sample_rate + + if AV_AVAILABLE: + self.resampler = AudioResampler( + format="s16", + layout="mono", + rate=target_sample_rate + ) + else: + logger.warning("PyAV not available, audio resampling disabled") + self.resampler = None + + self._closed = False + + async def recv(self): + """ + Receive and resample next audio frame. 
+ + Returns: + Resampled AudioFrame at 16kHz mono + """ + if self._closed: + raise RuntimeError("Track is closed") + + # Get frame from source track + frame = await self.track.recv() + + # Resample the frame if AV is available + if AV_AVAILABLE and self.resampler: + resampled_frame = self.resampler.resample(frame) + # Ensure the frame has the correct format + resampled_frame.sample_rate = self.target_sample_rate + return resampled_frame + else: + # Return frame as-is if AV is not available + return frame + + async def stop(self) -> None: + """Stop the track and cleanup resources.""" + self._closed = True + if hasattr(self, 'resampler') and self.resampler: + del self.resampler + logger.debug("Resampled track stopped") + + +class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object): + """ + Synthetic audio track that generates a sine wave. + + Useful for testing without requiring real audio input. + """ + + def __init__(self, sample_rate: int = 16000, frequency: int = 440): + """ + Initialize sine wave track. + + Args: + sample_rate: Audio sample rate (default: 16000) + frequency: Sine wave frequency in Hz (default: 440) + """ + if not AIORTC_AVAILABLE: + raise RuntimeError("aiortc not available - SineWaveTrack cannot be used") + + super().__init__() + self.sample_rate = sample_rate + self.frequency = frequency + self.counter = 0 + self._stopped = False + + async def recv(self): + """ + Generate next audio frame with sine wave. + + Returns: + AudioFrame with sine wave data + """ + if self._stopped: + raise RuntimeError("Track is stopped") + + # Generate 20ms of audio + samples = int(self.sample_rate * 0.02) + pts = self.counter + time_base = fractions.Fraction(1, self.sample_rate) + + # Generate sine wave + t = np.linspace( + self.counter / self.sample_rate, + (self.counter + samples) / self.sample_rate, + samples, + endpoint=False + ) + + # Generate sine wave (Int16 PCM) + data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16) + + # Update counter + self.counter += samples + + # Create AudioFrame if AV is available + if AV_AVAILABLE: + frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono') + frame.pts = pts + frame.time_base = time_base + frame.sample_rate = self.sample_rate + return frame + else: + # Return simple data structure if AV is not available + return { + 'data': data, + 'sample_rate': self.sample_rate, + 'pts': pts, + 'time_base': time_base + } + + def stop(self) -> None: + """Stop the track.""" + self._stopped = True diff --git a/engine/processors/vad.py b/engine/processors/vad.py new file mode 100644 index 0000000..cad6e8b --- /dev/null +++ b/engine/processors/vad.py @@ -0,0 +1,221 @@ +"""Voice Activity Detection using Silero VAD.""" + +import asyncio +import os +from typing import Tuple, Optional +import numpy as np +from loguru import logger + + +# Try to import onnxruntime (optional for VAD functionality) +try: + import onnxruntime as ort + ONNX_AVAILABLE = True +except ImportError: + ONNX_AVAILABLE = False + ort = None + logger.warning("onnxruntime not available - VAD will be disabled") + + +class SileroVAD: + """ + Voice Activity Detection using Silero VAD model. + + Detects speech in audio chunks using the Silero VAD ONNX model. + Returns "Speech" or "Silence" for each audio chunk. + """ + + def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000): + """ + Initialize Silero VAD. 
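+
+        If the ONNX model file or onnxruntime is missing, the detector stays usable:
+        process_audio() falls back to the energy-based VAD implemented below.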
+ + Args: + model_path: Path to Silero VAD ONNX model + sample_rate: Audio sample rate (must be 16kHz for Silero VAD) + """ + self.sample_rate = sample_rate + self.model_path = model_path + + # Check if model exists + if not os.path.exists(model_path): + logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.") + self.session = None + return + + # Check if onnxruntime is available + if not ONNX_AVAILABLE: + logger.warning("onnxruntime not available - VAD will be disabled") + self.session = None + return + + # Load ONNX model + try: + self.session = ort.InferenceSession(model_path) + logger.info(f"Loaded Silero VAD model from {model_path}") + except Exception as e: + logger.error(f"Failed to load VAD model: {e}") + self.session = None + return + + # Internal state for VAD + self._reset_state() + self.buffer = np.array([], dtype=np.float32) + self.min_chunk_size = 512 + self.last_label = "Silence" + self.last_probability = 0.0 + self._energy_noise_floor = 1e-4 + + def _reset_state(self): + # Silero VAD V4+ expects state shape [2, 1, 128] + self._state = np.zeros((2, 1, 128), dtype=np.float32) + self._sr = np.array([self.sample_rate], dtype=np.int64) + + def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]: + """ + Process audio chunk and detect speech. + + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic) + + Returns: + Tuple of (label, probability) where label is "Speech" or "Silence" + """ + if self.session is None or not ONNX_AVAILABLE: + # Fallback energy-based VAD with adaptive noise floor. + if not pcm_bytes: + return "Silence", 0.0 + audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16) + if audio_int16.size == 0: + return "Silence", 0.0 + audio_float = audio_int16.astype(np.float32) / 32768.0 + rms = float(np.sqrt(np.mean(audio_float * audio_float))) + + # Update adaptive noise floor (slowly rises, faster to fall) + if rms < self._energy_noise_floor: + self._energy_noise_floor = 0.95 * self._energy_noise_floor + 0.05 * rms + else: + self._energy_noise_floor = 0.995 * self._energy_noise_floor + 0.005 * rms + + # Compute SNR-like ratio and map to probability + denom = max(self._energy_noise_floor, 1e-6) + snr = max(0.0, (rms - denom) / denom) + probability = min(1.0, snr / 3.0) # ~3x above noise => strong speech + label = "Speech" if probability >= 0.5 else "Silence" + return label, probability + + # Convert bytes to numpy array of int16 + audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16) + + # Normalize to float32 (-1.0 to 1.0) + audio_float = audio_int16.astype(np.float32) / 32768.0 + + # Add to buffer + self.buffer = np.concatenate((self.buffer, audio_float)) + + # Process all complete chunks in the buffer + processed_any = False + while len(self.buffer) >= self.min_chunk_size: + # Slice exactly 512 samples + chunk = self.buffer[:self.min_chunk_size] + self.buffer = self.buffer[self.min_chunk_size:] + + # Prepare inputs + # Input tensor shape: [batch, samples] -> [1, 512] + input_tensor = chunk.reshape(1, -1) + + # Run inference + try: + ort_inputs = { + 'input': input_tensor, + 'state': self._state, + 'sr': self._sr + } + + # Outputs: probability, state + out, self._state = self.session.run(None, ort_inputs) + + # Get probability + self.last_probability = float(out[0][0]) + self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence" + processed_any = True + + except Exception as e: + logger.error(f"VAD inference error: {e}") 
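+                # Fail open on inference errors: after logging the model's expected
+                # input names below, return "Speech" so a transient ONNX failure never
+                # suppresses caller audio downstream.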
+ # Try to determine if it's an input name issue + try: + inputs = [x.name for x in self.session.get_inputs()] + logger.error(f"Model expects inputs: {inputs}") + except: + pass + return "Speech", 1.0 + + return self.last_label, self.last_probability + + def reset(self) -> None: + """Reset VAD internal state.""" + self._reset_state() + self.buffer = np.array([], dtype=np.float32) + self.last_label = "Silence" + self.last_probability = 0.0 + + +class VADProcessor: + """ + High-level VAD processor with state management. + + Tracks speech/silence state and emits events on transitions. + """ + + def __init__(self, vad_model: SileroVAD, threshold: float = 0.5): + """ + Initialize VAD processor. + + Args: + vad_model: Silero VAD model instance + threshold: Speech detection threshold + """ + self.vad = vad_model + self.threshold = threshold + self.is_speaking = False + self.speech_start_time: Optional[float] = None + self.silence_start_time: Optional[float] = None + + def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]: + """ + Process audio chunk and detect state changes. + + Args: + pcm_bytes: PCM audio data + chunk_size_ms: Chunk duration in milliseconds + + Returns: + Tuple of (event_type, probability) if state changed, None otherwise + """ + label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms) + + # Check if this is speech based on threshold + is_speech = probability >= self.threshold + + # State transition: Silence -> Speech + if is_speech and not self.is_speaking: + self.is_speaking = True + self.speech_start_time = asyncio.get_event_loop().time() + self.silence_start_time = None + return ("speaking", probability) + + # State transition: Speech -> Silence + elif not is_speech and self.is_speaking: + self.is_speaking = False + self.silence_start_time = asyncio.get_event_loop().time() + self.speech_start_time = None + return ("silence", probability) + + return None + + def reset(self) -> None: + """Reset VAD state.""" + self.vad.reset() + self.is_speaking = False + self.speech_start_time = None + self.silence_start_time = None diff --git a/engine/pyproject.toml b/engine/pyproject.toml new file mode 100644 index 0000000..8786905 --- /dev/null +++ b/engine/pyproject.toml @@ -0,0 +1,134 @@ +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "py-active-call-cc" +version = "0.1.0" +description = "Python Active-Call: Real-time audio streaming with WebSocket and WebRTC" +readme = "README.md" +requires-python = ">=3.11" +license = {text = "MIT"} +authors = [ + {name = "Your Name", email = "your.email@example.com"} +] +keywords = ["webrtc", "websocket", "audio", "voip", "real-time"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Communications :: Telephony", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +[project.urls] +Homepage = "https://github.com/yourusername/py-active-call-cc" +Documentation = "https://github.com/yourusername/py-active-call-cc/blob/main/README.md" +Repository = "https://github.com/yourusername/py-active-call-cc.git" +Issues = "https://github.com/yourusername/py-active-call-cc/issues" + +[tool.setuptools.packages.find] +where = ["."] +include = ["app*"] +exclude = ["tests*", "scripts*", "reference*"] + +[tool.black] +line-length = 100 +target-version = ['py311'] +include = 
'\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist + | reference +)/ +''' + +[tool.ruff] +line-length = 100 +target-version = "py311" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by black) + "B008", # do not perform function calls in argument defaults +] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "reference", +] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] # unused imports + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +strict_equality = true +exclude = [ + "venv", + "reference", + "build", + "dist", +] + +[[tool.mypy.overrides]] +module = [ + "aiortc.*", + "av.*", + "onnxruntime.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "-ra -q --strict-markers --strict-config" +testpaths = ["tests"] +pythonpath = ["."] +asyncio_mode = "auto" +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", +] diff --git a/engine/requirements.txt b/engine/requirements.txt new file mode 100644 index 0000000..3d38414 --- /dev/null +++ b/engine/requirements.txt @@ -0,0 +1,37 @@ +# Web Framework +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +websockets>=12.0 +python-multipart>=0.0.6 + +# WebRTC (optional - for WebRTC transport) +aiortc>=1.6.0 + +# Audio Processing +av>=12.1.0 +numpy>=1.26.3 +onnxruntime>=1.16.3 + +# Configuration +pydantic>=2.5.3 +pydantic-settings>=2.1.0 +python-dotenv>=1.0.0 +toml>=0.10.2 + +# Logging +loguru>=0.7.2 + +# HTTP Client +aiohttp>=3.9.1 + +# AI Services - LLM +openai>=1.0.0 + +# AI Services - TTS +edge-tts>=6.1.0 +pydub>=0.25.0 # For audio format conversion + +# Microphone client dependencies +sounddevice>=0.4.6 +soundfile>=0.12.1 +pyaudio>=0.2.13 # More reliable audio on Windows diff --git a/engine/scripts/README.md b/engine/scripts/README.md new file mode 100644 index 0000000..8b6f7a0 --- /dev/null +++ b/engine/scripts/README.md @@ -0,0 +1 @@ +# Development Script \ No newline at end of file diff --git a/engine/scripts/generate_test_audio/generate_test_audio.py b/engine/scripts/generate_test_audio/generate_test_audio.py new file mode 100644 index 0000000..9b37f5f --- /dev/null +++ b/engine/scripts/generate_test_audio/generate_test_audio.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +""" +Generate test audio file with utterances using SiliconFlow TTS API. + +Creates a 16kHz mono WAV file with real speech segments separated by +configurable silence (for VAD/testing). 
+
+Usage:
+    python generate_test_audio.py [OPTIONS]
+
+Options:
+    -o, --output PATH       Output WAV path (default: data/audio_examples/two_utterances_16k.wav)
+    -u, --utterance TEXT    Utterance text; repeat for multiple (ignored if -j is set)
+    -j, --json PATH         JSON file: array of strings or {"utterances": [...]}
+    --silence-ms MS         Silence in ms between utterances (default: 500)
+    --lead-silence-ms MS    Silence in ms at start (default: 200)
+    --trail-silence-ms MS   Silence in ms at end (default: 300)
+
+Examples:
+    # Default utterances and output
+    python generate_test_audio.py
+
+    # Custom output path
+    python generate_test_audio.py -o out.wav
+
+    # Utterances from command line
+    python generate_test_audio.py -u "Hello" -u "World" -o test.wav
+
+    # Utterances from a JSON file
+    python generate_test_audio.py -j utterances.json -o test.wav
+
+    # Custom silence (1s between utterances)
+    python generate_test_audio.py -u "One" -u "Two" --silence-ms 1000 -o test.wav
+
+Requires SILICONFLOW_API_KEY in .env.
+"""
+
+import wave
+import argparse
+import asyncio
+import aiohttp
+import json
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+
+
+# Load .env file from project root
+project_root = Path(__file__).parent.parent.parent
+load_dotenv(project_root / ".env")
+
+
+# SiliconFlow TTS Configuration
+SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/audio/speech"
+SILICONFLOW_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
+
+# Available voices
+VOICES = {
+    "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
+    "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
+    "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
+    "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
+    "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
+    "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
+    "david": "FunAudioLLM/CosyVoice2-0.5B:david",
+    "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
+}
+
+
+def generate_silence(duration_ms: int, sample_rate: int = 16000) -> bytes:
+    """Generate silence as PCM bytes."""
+    num_samples = int(sample_rate * (duration_ms / 1000.0))
+    return b'\x00\x00' * num_samples
+
+
+async def synthesize_speech(
+    text: str,
+    api_key: str,
+    voice: str = "anna",
+    sample_rate: int = 16000,
+    speed: float = 1.0
+) -> bytes:
+    """
+    Synthesize speech using SiliconFlow TTS API.
+ + Args: + text: Text to synthesize + api_key: SiliconFlow API key + voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana) + sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100) + speed: Speech speed (0.25 to 4.0) + + Returns: + PCM audio bytes (16-bit signed, little-endian) + """ + # Resolve voice name + full_voice = VOICES.get(voice, voice) + + payload = { + "model": SILICONFLOW_MODEL, + "input": text, + "voice": full_voice, + "response_format": "pcm", + "sample_rate": sample_rate, + "stream": False, + "speed": speed + } + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + async with aiohttp.ClientSession() as session: + async with session.post(SILICONFLOW_API_URL, json=payload, headers=headers) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"SiliconFlow TTS error: {response.status} - {error_text}") + + return await response.read() + + +async def generate_test_audio( + output_path: str, + utterances: list[str], + silence_ms: int = 500, + lead_silence_ms: int = 200, + trail_silence_ms: int = 300, + voice: str = "anna", + sample_rate: int = 16000, + speed: float = 1.0 +): + """ + Generate test audio with multiple utterances separated by silence. + + Args: + output_path: Path to save the WAV file + utterances: List of text strings for each utterance + silence_ms: Silence duration between utterances (milliseconds) + lead_silence_ms: Silence at the beginning (milliseconds) + trail_silence_ms: Silence at the end (milliseconds) + voice: TTS voice to use + sample_rate: Audio sample rate + speed: TTS speech speed + """ + api_key = os.getenv("SILICONFLOW_API_KEY") + if not api_key: + raise ValueError( + "SILICONFLOW_API_KEY not found in environment.\n" + "Please set it in your .env file:\n" + " SILICONFLOW_API_KEY=your-api-key-here" + ) + + print(f"Using SiliconFlow TTS API") + print(f" Voice: {voice}") + print(f" Sample rate: {sample_rate}Hz") + print(f" Speed: {speed}x") + print() + + segments = [] + + # Lead-in silence + if lead_silence_ms > 0: + segments.append(generate_silence(lead_silence_ms, sample_rate)) + print(f" [silence: {lead_silence_ms}ms]") + + # Generate each utterance with silence between + for i, text in enumerate(utterances): + print(f" Synthesizing utterance {i + 1}: \"{text}\"") + audio = await synthesize_speech( + text=text, + api_key=api_key, + voice=voice, + sample_rate=sample_rate, + speed=speed + ) + segments.append(audio) + + # Add silence between utterances (not after the last one) + if i < len(utterances) - 1: + segments.append(generate_silence(silence_ms, sample_rate)) + print(f" [silence: {silence_ms}ms]") + + # Trail silence + if trail_silence_ms > 0: + segments.append(generate_silence(trail_silence_ms, sample_rate)) + print(f" [silence: {trail_silence_ms}ms]") + + # Concatenate all segments + audio_data = b''.join(segments) + + # Write WAV file + with wave.open(output_path, 'wb') as wf: + wf.setnchannels(1) # Mono + wf.setsampwidth(2) # 16-bit + wf.setframerate(sample_rate) + wf.writeframes(audio_data) + + duration_sec = len(audio_data) / (sample_rate * 2) + print() + print(f"Generated: {output_path}") + print(f" Duration: {duration_sec:.2f}s") + print(f" Sample rate: {sample_rate}Hz") + print(f" Format: 16-bit mono PCM WAV") + print(f" Size: {len(audio_data):,} bytes") + + +def load_utterances_from_json(path: Path) -> list[str]: + """ + Load utterances from a JSON file. 
+ + Accepts either: + - A JSON array: ["utterance 1", "utterance 2"] + - A JSON object with "utterances" key: {"utterances": ["a", "b"]} + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + if isinstance(data, list): + return [str(s) for s in data] + if isinstance(data, dict) and "utterances" in data: + return [str(s) for s in data["utterances"]] + raise ValueError( + f"JSON file must be an array of strings or an object with 'utterances' key. " + f"Got: {type(data).__name__}" + ) + + +def parse_args(): + """Parse command-line arguments.""" + script_dir = Path(__file__).parent + default_output = script_dir.parent / "data" / "audio_examples" / "two_utterances_16k.wav" + + parser = argparse.ArgumentParser(description="Generate test audio with SiliconFlow TTS (utterances + silence).") + parser.add_argument( + "-o", "--output", + type=Path, + default=default_output, + help=f"Output WAV file path (default: {default_output})" + ) + parser.add_argument( + "-u", "--utterance", + action="append", + dest="utterances", + metavar="TEXT", + help="Utterance text (repeat for multiple). Ignored if --json is set." + ) + parser.add_argument( + "-j", "--json", + type=Path, + metavar="PATH", + help="JSON file with utterances: array of strings or object with 'utterances' key" + ) + parser.add_argument( + "--silence-ms", + type=int, + default=500, + metavar="MS", + help="Silence in ms between utterances (default: 500)" + ) + parser.add_argument( + "--lead-silence-ms", + type=int, + default=200, + metavar="MS", + help="Silence in ms at start of file (default: 200)" + ) + parser.add_argument( + "--trail-silence-ms", + type=int, + default=300, + metavar="MS", + help="Silence in ms at end of file (default: 300)" + ) + return parser.parse_args() + + +async def main(): + """Main entry point.""" + args = parse_args() + output_path = args.output + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Resolve utterances: JSON file > -u args > defaults + if args.json is not None: + if not args.json.is_file(): + raise FileNotFoundError(f"Utterances JSON file not found: {args.json}") + utterances = load_utterances_from_json(args.json) + if not utterances: + raise ValueError(f"JSON file has no utterances: {args.json}") + elif args.utterances: + utterances = args.utterances + else: + utterances = [ + "Hello, how are you doing today?", + "I'm doing great, thank you for asking!" + ] + + await generate_test_audio( + output_path=str(output_path), + utterances=utterances, + silence_ms=args.silence_ms, + lead_silence_ms=args.lead_silence_ms, + trail_silence_ms=args.trail_silence_ms, + voice="anna", + sample_rate=16000, + speed=1.0 + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/engine/services/__init__.py b/engine/services/__init__.py new file mode 100644 index 0000000..5a53e70 --- /dev/null +++ b/engine/services/__init__.py @@ -0,0 +1,47 @@ +"""AI Services package. + +Provides ASR, LLM, TTS, and Realtime API services for voice conversation. 
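+
+A minimal usage sketch (the mock services need no API keys; run inside an
+asyncio event loop):
+
+    from services import MockLLMService, LLMMessage
+
+    llm = MockLLMService()
+    await llm.connect()
+    reply = await llm.generate([LLMMessage(role="user", content="Hello")])
+    await llm.disconnect()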
+""" + +from services.base import ( + ServiceState, + ASRResult, + LLMMessage, + TTSChunk, + BaseASRService, + BaseLLMService, + BaseTTSService, +) +from services.llm import OpenAILLMService, MockLLMService +from services.tts import EdgeTTSService, MockTTSService +from services.asr import BufferedASRService, MockASRService +from services.siliconflow_asr import SiliconFlowASRService +from services.siliconflow_tts import SiliconFlowTTSService +from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline + +__all__ = [ + # Base classes + "ServiceState", + "ASRResult", + "LLMMessage", + "TTSChunk", + "BaseASRService", + "BaseLLMService", + "BaseTTSService", + # LLM + "OpenAILLMService", + "MockLLMService", + # TTS + "EdgeTTSService", + "MockTTSService", + # ASR + "BufferedASRService", + "MockASRService", + "SiliconFlowASRService", + # TTS (SiliconFlow) + "SiliconFlowTTSService", + # Realtime + "RealtimeService", + "RealtimeConfig", + "RealtimePipeline", +] diff --git a/engine/services/asr.py b/engine/services/asr.py new file mode 100644 index 0000000..51ab584 --- /dev/null +++ b/engine/services/asr.py @@ -0,0 +1,147 @@ +"""ASR (Automatic Speech Recognition) Service implementations. + +Provides speech-to-text capabilities with streaming support. +""" + +import os +import asyncio +import json +from typing import AsyncIterator, Optional +from loguru import logger + +from services.base import BaseASRService, ASRResult, ServiceState + +# Try to import websockets for streaming ASR +try: + import websockets + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + + +class BufferedASRService(BaseASRService): + """ + Buffered ASR service that accumulates audio and provides + a simple text accumulator for use with EOU detection. + + This is a lightweight implementation that works with the + existing VAD + EOU pattern without requiring external ASR. + """ + + def __init__( + self, + sample_rate: int = 16000, + language: str = "en" + ): + super().__init__(sample_rate=sample_rate, language=language) + + self._audio_buffer: bytes = b"" + self._current_text: str = "" + self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() + + async def connect(self) -> None: + """No connection needed for buffered ASR.""" + self.state = ServiceState.CONNECTED + logger.info("Buffered ASR service connected") + + async def disconnect(self) -> None: + """Clear buffers on disconnect.""" + self._audio_buffer = b"" + self._current_text = "" + self.state = ServiceState.DISCONNECTED + logger.info("Buffered ASR service disconnected") + + async def send_audio(self, audio: bytes) -> None: + """Buffer audio for later processing.""" + self._audio_buffer += audio + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """Yield transcription results.""" + while True: + try: + result = await asyncio.wait_for( + self._transcript_queue.get(), + timeout=0.1 + ) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + def set_text(self, text: str) -> None: + """ + Set the current transcript text directly. + + This allows external integration (e.g., Whisper, other ASR) + to provide transcripts. 
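+
+        Sketch (external_transcribe is a hypothetical helper, e.g. a Whisper
+        wrapper; call this while an event loop is running, since it schedules
+        a queue put as a task):
+
+            text = external_transcribe(asr.get_audio_buffer())
+            asr.set_text(text)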
+ """ + self._current_text = text + result = ASRResult(text=text, is_final=False) + asyncio.create_task(self._transcript_queue.put(result)) + + def get_and_clear_text(self) -> str: + """Get accumulated text and clear buffer.""" + text = self._current_text + self._current_text = "" + self._audio_buffer = b"" + return text + + def get_audio_buffer(self) -> bytes: + """Get accumulated audio buffer.""" + return self._audio_buffer + + def clear_audio_buffer(self) -> None: + """Clear audio buffer.""" + self._audio_buffer = b"" + + +class MockASRService(BaseASRService): + """ + Mock ASR service for testing without actual recognition. + """ + + def __init__(self, sample_rate: int = 16000, language: str = "en"): + super().__init__(sample_rate=sample_rate, language=language) + self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() + self._mock_texts = [ + "Hello, how are you?", + "That's interesting.", + "Tell me more about that.", + "I understand.", + ] + self._text_index = 0 + + async def connect(self) -> None: + self.state = ServiceState.CONNECTED + logger.info("Mock ASR service connected") + + async def disconnect(self) -> None: + self.state = ServiceState.DISCONNECTED + logger.info("Mock ASR service disconnected") + + async def send_audio(self, audio: bytes) -> None: + """Mock audio processing - generates fake transcripts periodically.""" + pass + + def trigger_transcript(self) -> None: + """Manually trigger a transcript (for testing).""" + text = self._mock_texts[self._text_index % len(self._mock_texts)] + self._text_index += 1 + + result = ASRResult(text=text, is_final=True, confidence=0.95) + asyncio.create_task(self._transcript_queue.put(result)) + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """Yield transcription results.""" + while True: + try: + result = await asyncio.wait_for( + self._transcript_queue.get(), + timeout=0.1 + ) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break diff --git a/engine/services/base.py b/engine/services/base.py new file mode 100644 index 0000000..420428b --- /dev/null +++ b/engine/services/base.py @@ -0,0 +1,244 @@ +"""Base classes for AI services. + +Defines abstract interfaces for ASR, LLM, and TTS services, +inspired by pipecat's service architecture and active-call's +StreamEngine pattern. 
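+
+Conceptually, one voice turn flows through the three interfaces as:
+
+    PCM audio -> BaseASRService -> text -> BaseLLMService -> reply text
+              -> BaseTTSService -> PCM audio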
+""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import AsyncIterator, Optional, List, Dict, Any +from enum import Enum + + +class ServiceState(Enum): + """Service connection state.""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + ERROR = "error" + + +@dataclass +class ASRResult: + """ASR transcription result.""" + text: str + is_final: bool = False + confidence: float = 1.0 + language: Optional[str] = None + start_time: Optional[float] = None + end_time: Optional[float] = None + + def __str__(self) -> str: + status = "FINAL" if self.is_final else "PARTIAL" + return f"[{status}] {self.text}" + + +@dataclass +class LLMMessage: + """LLM conversation message.""" + role: str # "system", "user", "assistant", "function" + content: str + name: Optional[str] = None # For function calls + function_call: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to API-compatible dict.""" + d = {"role": self.role, "content": self.content} + if self.name: + d["name"] = self.name + if self.function_call: + d["function_call"] = self.function_call + return d + + +@dataclass +class TTSChunk: + """TTS audio chunk.""" + audio: bytes # PCM audio data + sample_rate: int = 16000 + channels: int = 1 + bits_per_sample: int = 16 + is_final: bool = False + text_offset: Optional[int] = None # Character offset in original text + + +class BaseASRService(ABC): + """ + Abstract base class for ASR (Speech-to-Text) services. + + Supports both streaming and non-streaming transcription. + """ + + def __init__(self, sample_rate: int = 16000, language: str = "en"): + self.sample_rate = sample_rate + self.language = language + self.state = ServiceState.DISCONNECTED + + @abstractmethod + async def connect(self) -> None: + """Establish connection to ASR service.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close connection to ASR service.""" + pass + + @abstractmethod + async def send_audio(self, audio: bytes) -> None: + """ + Send audio chunk for transcription. + + Args: + audio: PCM audio data (16-bit, mono) + """ + pass + + @abstractmethod + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """ + Receive transcription results. + + Yields: + ASRResult objects as they become available + """ + pass + + async def transcribe(self, audio: bytes) -> ASRResult: + """ + Transcribe a complete audio buffer (non-streaming). + + Args: + audio: Complete PCM audio data + + Returns: + Final ASRResult + """ + # Default implementation using streaming + await self.send_audio(audio) + async for result in self.receive_transcripts(): + if result.is_final: + return result + return ASRResult(text="", is_final=True) + + +class BaseLLMService(ABC): + """ + Abstract base class for LLM (Language Model) services. + + Supports streaming responses for real-time conversation. + """ + + def __init__(self, model: str = "gpt-4"): + self.model = model + self.state = ServiceState.DISCONNECTED + + @abstractmethod + async def connect(self) -> None: + """Initialize LLM service connection.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close LLM service connection.""" + pass + + @abstractmethod + async def generate( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> str: + """ + Generate a complete response. 
+ + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Returns: + Complete response text + """ + pass + + @abstractmethod + async def generate_stream( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> AsyncIterator[str]: + """ + Generate response in streaming mode. + + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Yields: + Text chunks as they are generated + """ + pass + + +class BaseTTSService(ABC): + """ + Abstract base class for TTS (Text-to-Speech) services. + + Supports streaming audio synthesis for low-latency playback. + """ + + def __init__( + self, + voice: str = "default", + sample_rate: int = 16000, + speed: float = 1.0 + ): + self.voice = voice + self.sample_rate = sample_rate + self.speed = speed + self.state = ServiceState.DISCONNECTED + + @abstractmethod + async def connect(self) -> None: + """Initialize TTS service connection.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close TTS service connection.""" + pass + + @abstractmethod + async def synthesize(self, text: str) -> bytes: + """ + Synthesize complete audio for text (non-streaming). + + Args: + text: Text to synthesize + + Returns: + Complete PCM audio data + """ + pass + + @abstractmethod + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """ + Synthesize audio in streaming mode. + + Args: + text: Text to synthesize + + Yields: + TTSChunk objects as audio is generated + """ + pass + + async def cancel(self) -> None: + """Cancel ongoing synthesis (for barge-in support).""" + pass diff --git a/engine/services/llm.py b/engine/services/llm.py new file mode 100644 index 0000000..e1d99a8 --- /dev/null +++ b/engine/services/llm.py @@ -0,0 +1,239 @@ +"""LLM (Large Language Model) Service implementations. + +Provides OpenAI-compatible LLM integration with streaming support +for real-time voice conversation. +""" + +import os +import asyncio +from typing import AsyncIterator, Optional, List, Dict, Any +from loguru import logger + +from services.base import BaseLLMService, LLMMessage, ServiceState + +# Try to import openai +try: + from openai import AsyncOpenAI + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + logger.warning("openai package not available - LLM service will be disabled") + + +class OpenAILLMService(BaseLLMService): + """ + OpenAI-compatible LLM service. + + Supports streaming responses for low-latency voice conversation. + Works with OpenAI API, Azure OpenAI, and compatible APIs. + """ + + def __init__( + self, + model: str = "gpt-4o-mini", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + system_prompt: Optional[str] = None + ): + """ + Initialize OpenAI LLM service. + + Args: + model: Model name (e.g., "gpt-4o-mini", "gpt-4o") + api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + base_url: Custom API base URL (for Azure or compatible APIs) + system_prompt: Default system prompt for conversations + """ + super().__init__(model=model) + + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.base_url = base_url or os.getenv("OPENAI_API_URL") + self.system_prompt = system_prompt or ( + "You are a helpful, friendly voice assistant. " + "Keep your responses concise and conversational. " + "Respond naturally as if having a phone conversation." 
+ ) + + self.client: Optional[AsyncOpenAI] = None + self._cancel_event = asyncio.Event() + + async def connect(self) -> None: + """Initialize OpenAI client.""" + if not OPENAI_AVAILABLE: + raise RuntimeError("openai package not installed") + + if not self.api_key: + raise ValueError("OpenAI API key not provided") + + self.client = AsyncOpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + self.state = ServiceState.CONNECTED + logger.info(f"OpenAI LLM service connected: model={self.model}") + + async def disconnect(self) -> None: + """Close OpenAI client.""" + if self.client: + await self.client.close() + self.client = None + self.state = ServiceState.DISCONNECTED + logger.info("OpenAI LLM service disconnected") + + def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]: + """Prepare messages list with system prompt.""" + result = [] + + # Add system prompt if not already present + has_system = any(m.role == "system" for m in messages) + if not has_system and self.system_prompt: + result.append({"role": "system", "content": self.system_prompt}) + + # Add all messages + for msg in messages: + result.append(msg.to_dict()) + + return result + + async def generate( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> str: + """ + Generate a complete response. + + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Returns: + Complete response text + """ + if not self.client: + raise RuntimeError("LLM service not connected") + + prepared = self._prepare_messages(messages) + + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=prepared, + temperature=temperature, + max_tokens=max_tokens + ) + + content = response.choices[0].message.content or "" + logger.debug(f"LLM response: {content[:100]}...") + return content + + except Exception as e: + logger.error(f"LLM generation error: {e}") + raise + + async def generate_stream( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> AsyncIterator[str]: + """ + Generate response in streaming mode. + + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Yields: + Text chunks as they are generated + """ + if not self.client: + raise RuntimeError("LLM service not connected") + + prepared = self._prepare_messages(messages) + self._cancel_event.clear() + + try: + stream = await self.client.chat.completions.create( + model=self.model, + messages=prepared, + temperature=temperature, + max_tokens=max_tokens, + stream=True + ) + + async for chunk in stream: + # Check for cancellation + if self._cancel_event.is_set(): + logger.info("LLM stream cancelled") + break + + if chunk.choices and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + yield content + + except asyncio.CancelledError: + logger.info("LLM stream cancelled via asyncio") + raise + except Exception as e: + logger.error(f"LLM streaming error: {e}") + raise + + def cancel(self) -> None: + """Cancel ongoing generation.""" + self._cancel_event.set() + + +class MockLLMService(BaseLLMService): + """ + Mock LLM service for testing without API calls. + """ + + def __init__(self, response_delay: float = 0.5): + super().__init__(model="mock") + self.response_delay = response_delay + self.responses = [ + "Hello! How can I help you today?", + "That's an interesting question. 
Let me think about it.", + "I understand. Is there anything else you'd like to know?", + "Great! I'm here if you need anything else.", + ] + self._response_index = 0 + + async def connect(self) -> None: + self.state = ServiceState.CONNECTED + logger.info("Mock LLM service connected") + + async def disconnect(self) -> None: + self.state = ServiceState.DISCONNECTED + logger.info("Mock LLM service disconnected") + + async def generate( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> str: + await asyncio.sleep(self.response_delay) + response = self.responses[self._response_index % len(self.responses)] + self._response_index += 1 + return response + + async def generate_stream( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> AsyncIterator[str]: + response = await self.generate(messages, temperature, max_tokens) + + # Stream word by word + words = response.split() + for i, word in enumerate(words): + if i > 0: + yield " " + yield word + await asyncio.sleep(0.05) # Simulate streaming delay diff --git a/engine/services/realtime.py b/engine/services/realtime.py new file mode 100644 index 0000000..3fd95c1 --- /dev/null +++ b/engine/services/realtime.py @@ -0,0 +1,548 @@ +"""OpenAI Realtime API Service. + +Provides true duplex voice conversation using OpenAI's Realtime API, +similar to active-call's RealtimeProcessor. This bypasses the need for +separate ASR/LLM/TTS services by handling everything server-side. + +The Realtime API provides: +- Server-side VAD with turn detection +- Streaming speech-to-text +- Streaming LLM responses +- Streaming text-to-speech +- Function calling support +- Barge-in/interruption handling +""" + +import os +import asyncio +import json +import base64 +from typing import Optional, Dict, Any, Callable, Awaitable, List +from dataclasses import dataclass, field +from enum import Enum +from loguru import logger + +try: + import websockets + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + logger.warning("websockets not available - Realtime API will be disabled") + + +class RealtimeState(Enum): + """Realtime API connection state.""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + ERROR = "error" + + +@dataclass +class RealtimeConfig: + """Configuration for OpenAI Realtime API.""" + + # API Configuration + api_key: Optional[str] = None + model: str = "gpt-4o-realtime-preview" + endpoint: Optional[str] = None # For Azure or custom endpoints + + # Voice Configuration + voice: str = "alloy" # alloy, echo, shimmer, etc. + instructions: str = ( + "You are a helpful, friendly voice assistant. " + "Keep your responses concise and conversational." + ) + + # Turn Detection (Server-side VAD) + turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500 + }) + + # Audio Configuration + input_audio_format: str = "pcm16" + output_audio_format: str = "pcm16" + + # Tools/Functions + tools: List[Dict[str, Any]] = field(default_factory=list) + + +class RealtimeService: + """ + OpenAI Realtime API service for true duplex voice conversation. + + This service handles the entire voice conversation pipeline: + 1. Audio input → Server-side VAD → Speech-to-text + 2. Text → LLM processing → Response generation + 3. 
Response → Text-to-speech → Audio output + + Events emitted: + - on_audio: Audio output from the assistant + - on_transcript: Text transcript (user or assistant) + - on_speech_started: User started speaking + - on_speech_stopped: User stopped speaking + - on_response_started: Assistant started responding + - on_response_done: Assistant finished responding + - on_function_call: Function call requested + - on_error: Error occurred + """ + + def __init__(self, config: Optional[RealtimeConfig] = None): + """ + Initialize Realtime API service. + + Args: + config: Realtime configuration (uses defaults if not provided) + """ + self.config = config or RealtimeConfig() + self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") + + self.state = RealtimeState.DISCONNECTED + self._ws = None + self._receive_task: Optional[asyncio.Task] = None + self._cancel_event = asyncio.Event() + + # Event callbacks + self._callbacks: Dict[str, List[Callable]] = { + "on_audio": [], + "on_transcript": [], + "on_speech_started": [], + "on_speech_stopped": [], + "on_response_started": [], + "on_response_done": [], + "on_function_call": [], + "on_error": [], + "on_interrupted": [], + } + + logger.debug(f"RealtimeService initialized with model={self.config.model}") + + def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None: + """ + Register event callback. + + Args: + event: Event name + callback: Async callback function + """ + if event in self._callbacks: + self._callbacks[event].append(callback) + + async def _emit(self, event: str, *args, **kwargs) -> None: + """Emit event to all registered callbacks.""" + for callback in self._callbacks.get(event, []): + try: + await callback(*args, **kwargs) + except Exception as e: + logger.error(f"Event callback error ({event}): {e}") + + async def connect(self) -> None: + """Connect to OpenAI Realtime API.""" + if not WEBSOCKETS_AVAILABLE: + raise RuntimeError("websockets package not installed") + + if not self.config.api_key: + raise ValueError("OpenAI API key not provided") + + self.state = RealtimeState.CONNECTING + + # Build URL + if self.config.endpoint: + # Azure or custom endpoint + url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}" + else: + # OpenAI endpoint + url = f"wss://api.openai.com/v1/realtime?model={self.config.model}" + + # Build headers + headers = {} + if self.config.endpoint: + headers["api-key"] = self.config.api_key + else: + headers["Authorization"] = f"Bearer {self.config.api_key}" + headers["OpenAI-Beta"] = "realtime=v1" + + try: + logger.info(f"Connecting to Realtime API: {url}") + self._ws = await websockets.connect(url, extra_headers=headers) + + # Send session configuration + await self._configure_session() + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + self.state = RealtimeState.CONNECTED + logger.info("Realtime API connected successfully") + + except Exception as e: + self.state = RealtimeState.ERROR + logger.error(f"Realtime API connection failed: {e}") + raise + + async def _configure_session(self) -> None: + """Send session configuration to server.""" + session_config = { + "type": "session.update", + "session": { + "modalities": ["text", "audio"], + "instructions": self.config.instructions, + "voice": self.config.voice, + "input_audio_format": self.config.input_audio_format, + "output_audio_format": self.config.output_audio_format, + "turn_detection": self.config.turn_detection, + } + } + + if 
self.config.tools: + session_config["session"]["tools"] = self.config.tools + + await self._send(session_config) + logger.debug("Session configuration sent") + + async def _send(self, data: Dict[str, Any]) -> None: + """Send JSON data to server.""" + if self._ws: + await self._ws.send(json.dumps(data)) + + async def send_audio(self, audio_bytes: bytes) -> None: + """ + Send audio to the Realtime API. + + Args: + audio_bytes: PCM audio data (16-bit, mono, 24kHz by default) + """ + if self.state != RealtimeState.CONNECTED: + return + + # Encode audio as base64 + audio_b64 = base64.standard_b64encode(audio_bytes).decode() + + await self._send({ + "type": "input_audio_buffer.append", + "audio": audio_b64 + }) + + async def send_text(self, text: str) -> None: + """ + Send text input (bypassing audio). + + Args: + text: User text input + """ + if self.state != RealtimeState.CONNECTED: + return + + # Create a conversation item with user text + await self._send({ + "type": "conversation.item.create", + "item": { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}] + } + }) + + # Trigger response + await self._send({"type": "response.create"}) + + async def cancel_response(self) -> None: + """Cancel the current response (for barge-in).""" + if self.state != RealtimeState.CONNECTED: + return + + await self._send({"type": "response.cancel"}) + logger.debug("Response cancelled") + + async def commit_audio(self) -> None: + """Commit the audio buffer and trigger response.""" + if self.state != RealtimeState.CONNECTED: + return + + await self._send({"type": "input_audio_buffer.commit"}) + await self._send({"type": "response.create"}) + + async def clear_audio_buffer(self) -> None: + """Clear the input audio buffer.""" + if self.state != RealtimeState.CONNECTED: + return + + await self._send({"type": "input_audio_buffer.clear"}) + + async def submit_function_result(self, call_id: str, result: str) -> None: + """ + Submit function call result. 
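+
+        Typically called from an ``on_function_call`` handler. A minimal sketch,
+        assuming ``service`` is this RealtimeService and ``run_tool`` is a
+        hypothetical dispatcher for the registered tools::
+
+            async def handle_call(call_id: str, name: str, arguments: str) -> None:
+                result = await run_tool(name, json.loads(arguments))  # run_tool is illustrative
+                await service.submit_function_result(call_id, json.dumps(result))
+
+            service.on("on_function_call", handle_call)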
+ + Args: + call_id: The function call ID + result: JSON string result + """ + if self.state != RealtimeState.CONNECTED: + return + + await self._send({ + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result + } + }) + + # Trigger response with the function result + await self._send({"type": "response.create"}) + + async def _receive_loop(self) -> None: + """Receive and process messages from the Realtime API.""" + if not self._ws: + return + + try: + async for message in self._ws: + try: + data = json.loads(message) + await self._handle_event(data) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON received: {message[:100]}") + + except asyncio.CancelledError: + logger.debug("Receive loop cancelled") + except websockets.ConnectionClosed as e: + logger.info(f"WebSocket closed: {e}") + self.state = RealtimeState.DISCONNECTED + except Exception as e: + logger.error(f"Receive loop error: {e}") + self.state = RealtimeState.ERROR + + async def _handle_event(self, data: Dict[str, Any]) -> None: + """Handle incoming event from Realtime API.""" + event_type = data.get("type", "unknown") + + # Audio delta - streaming audio output + if event_type == "response.audio.delta": + if "delta" in data: + audio_bytes = base64.standard_b64decode(data["delta"]) + await self._emit("on_audio", audio_bytes) + + # Audio transcript delta - streaming text + elif event_type == "response.audio_transcript.delta": + if "delta" in data: + await self._emit("on_transcript", data["delta"], "assistant", False) + + # Audio transcript done + elif event_type == "response.audio_transcript.done": + if "transcript" in data: + await self._emit("on_transcript", data["transcript"], "assistant", True) + + # Input audio transcript (user speech) + elif event_type == "conversation.item.input_audio_transcription.completed": + if "transcript" in data: + await self._emit("on_transcript", data["transcript"], "user", True) + + # Speech started (server VAD detected speech) + elif event_type == "input_audio_buffer.speech_started": + await self._emit("on_speech_started", data.get("audio_start_ms", 0)) + + # Speech stopped + elif event_type == "input_audio_buffer.speech_stopped": + await self._emit("on_speech_stopped", data.get("audio_end_ms", 0)) + + # Response started + elif event_type == "response.created": + await self._emit("on_response_started", data.get("response", {})) + + # Response done + elif event_type == "response.done": + await self._emit("on_response_done", data.get("response", {})) + + # Function call + elif event_type == "response.function_call_arguments.done": + call_id = data.get("call_id") + name = data.get("name") + arguments = data.get("arguments", "{}") + await self._emit("on_function_call", call_id, name, arguments) + + # Error + elif event_type == "error": + error = data.get("error", {}) + logger.error(f"Realtime API error: {error}") + await self._emit("on_error", error) + + # Session events + elif event_type == "session.created": + logger.info("Session created") + elif event_type == "session.updated": + logger.debug("Session updated") + + else: + logger.debug(f"Unhandled event type: {event_type}") + + async def disconnect(self) -> None: + """Disconnect from Realtime API.""" + self._cancel_event.set() + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + + if self._ws: + await self._ws.close() + self._ws = None + + self.state = RealtimeState.DISCONNECTED + 
logger.info("Realtime API disconnected") + + +class RealtimePipeline: + """ + Pipeline adapter for RealtimeService. + + Provides a compatible interface with DuplexPipeline but uses + OpenAI Realtime API for all processing. + """ + + def __init__( + self, + transport, + session_id: str, + config: Optional[RealtimeConfig] = None + ): + """ + Initialize Realtime pipeline. + + Args: + transport: Transport for sending audio/events + session_id: Session identifier + config: Realtime configuration + """ + self.transport = transport + self.session_id = session_id + + self.service = RealtimeService(config) + + # Register callbacks + self.service.on("on_audio", self._on_audio) + self.service.on("on_transcript", self._on_transcript) + self.service.on("on_speech_started", self._on_speech_started) + self.service.on("on_speech_stopped", self._on_speech_stopped) + self.service.on("on_response_started", self._on_response_started) + self.service.on("on_response_done", self._on_response_done) + self.service.on("on_error", self._on_error) + + self._is_speaking = False + self._running = True + + logger.info(f"RealtimePipeline initialized for session {session_id}") + + async def start(self) -> None: + """Start the pipeline.""" + await self.service.connect() + + async def process_audio(self, pcm_bytes: bytes) -> None: + """ + Process incoming audio. + + Note: Realtime API expects 24kHz audio by default. + You may need to resample from 16kHz. + """ + if not self._running: + return + + # TODO: Resample from 16kHz to 24kHz if needed + await self.service.send_audio(pcm_bytes) + + async def process_text(self, text: str) -> None: + """Process text input.""" + if not self._running: + return + + await self.service.send_text(text) + + async def interrupt(self) -> None: + """Interrupt current response.""" + await self.service.cancel_response() + await self.transport.send_event({ + "event": "interrupt", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + async def cleanup(self) -> None: + """Cleanup resources.""" + self._running = False + await self.service.disconnect() + + # Event handlers + + async def _on_audio(self, audio_bytes: bytes) -> None: + """Handle audio output.""" + await self.transport.send_audio(audio_bytes) + + async def _on_transcript(self, text: str, role: str, is_final: bool) -> None: + """Handle transcript.""" + logger.info(f"[{role.upper()}] {text[:50]}..." 
if len(text) > 50 else f"[{role.upper()}] {text}") + + async def _on_speech_started(self, start_ms: int) -> None: + """Handle user speech start.""" + self._is_speaking = True + await self.transport.send_event({ + "event": "speaking", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "startTime": start_ms + }) + + # Cancel any ongoing response (barge-in) + await self.service.cancel_response() + + async def _on_speech_stopped(self, end_ms: int) -> None: + """Handle user speech stop.""" + self._is_speaking = False + await self.transport.send_event({ + "event": "silence", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "duration": end_ms + }) + + async def _on_response_started(self, response: Dict) -> None: + """Handle response start.""" + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + async def _on_response_done(self, response: Dict) -> None: + """Handle response complete.""" + await self.transport.send_event({ + "event": "trackEnd", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + async def _on_error(self, error: Dict) -> None: + """Handle error.""" + await self.transport.send_event({ + "event": "error", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "sender": "realtime", + "error": str(error) + }) + + def _get_timestamp_ms(self) -> int: + """Get current timestamp in milliseconds.""" + import time + return int(time.time() * 1000) + + @property + def is_speaking(self) -> bool: + """Check if user is speaking.""" + return self._is_speaking diff --git a/engine/services/siliconflow_asr.py b/engine/services/siliconflow_asr.py new file mode 100644 index 0000000..6d67ad7 --- /dev/null +++ b/engine/services/siliconflow_asr.py @@ -0,0 +1,317 @@ +"""SiliconFlow ASR (Automatic Speech Recognition) Service. + +Uses the SiliconFlow API for speech-to-text transcription. +API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions +""" + +import asyncio +import io +import wave +from typing import AsyncIterator, Optional, Callable, Awaitable +from loguru import logger + +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + logger.warning("aiohttp not available - SiliconFlowASRService will not work") + +from services.base import BaseASRService, ASRResult, ServiceState + + +class SiliconFlowASRService(BaseASRService): + """ + SiliconFlow ASR service for speech-to-text transcription. 
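+
+    Audio is accepted as raw 16-bit mono PCM via ``send_audio`` and wrapped as
+    an in-memory WAV file for each transcription request.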
+ + Features: + - Buffers incoming audio chunks + - Provides interim transcriptions periodically (for streaming to client) + - Final transcription on EOU + + API Details: + - Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions + - Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR + - Input: Audio file (multipart/form-data) + - Output: {"text": "transcribed text"} + """ + + # Supported models + MODELS = { + "sensevoice": "FunAudioLLM/SenseVoiceSmall", + "telespeech": "TeleAI/TeleSpeechASR", + } + + API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions" + + def __init__( + self, + api_key: str, + model: str = "FunAudioLLM/SenseVoiceSmall", + sample_rate: int = 16000, + language: str = "auto", + interim_interval_ms: int = 500, # How often to send interim results + min_audio_for_interim_ms: int = 300, # Min audio before first interim + on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None + ): + """ + Initialize SiliconFlow ASR service. + + Args: + api_key: SiliconFlow API key + model: ASR model name or alias + sample_rate: Audio sample rate (16000 recommended) + language: Language code (auto for automatic detection) + interim_interval_ms: How often to generate interim transcriptions + min_audio_for_interim_ms: Minimum audio duration before first interim + on_transcript: Callback for transcription results (text, is_final) + """ + super().__init__(sample_rate=sample_rate, language=language) + + if not AIOHTTP_AVAILABLE: + raise RuntimeError("aiohttp is required for SiliconFlowASRService") + + self.api_key = api_key + self.model = self.MODELS.get(model.lower(), model) + self.interim_interval_ms = interim_interval_ms + self.min_audio_for_interim_ms = min_audio_for_interim_ms + self.on_transcript = on_transcript + + # Session + self._session: Optional[aiohttp.ClientSession] = None + + # Audio buffer + self._audio_buffer: bytes = b"" + self._current_text: str = "" + self._last_interim_time: float = 0 + + # Transcript queue for async iteration + self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() + + # Background task for interim results + self._interim_task: Optional[asyncio.Task] = None + self._running = False + + logger.info(f"SiliconFlowASRService initialized with model: {self.model}") + + async def connect(self) -> None: + """Connect to the service.""" + self._session = aiohttp.ClientSession( + headers={ + "Authorization": f"Bearer {self.api_key}" + } + ) + self._running = True + self.state = ServiceState.CONNECTED + logger.info("SiliconFlowASRService connected") + + async def disconnect(self) -> None: + """Disconnect and cleanup.""" + self._running = False + + if self._interim_task: + self._interim_task.cancel() + try: + await self._interim_task + except asyncio.CancelledError: + pass + self._interim_task = None + + if self._session: + await self._session.close() + self._session = None + + self._audio_buffer = b"" + self._current_text = "" + self.state = ServiceState.DISCONNECTED + logger.info("SiliconFlowASRService disconnected") + + async def send_audio(self, audio: bytes) -> None: + """ + Buffer incoming audio data. + + Args: + audio: PCM audio data (16-bit, mono) + """ + self._audio_buffer += audio + + async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]: + """ + Transcribe current audio buffer. 
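+
+        Interim passes are skipped until at least ``min_audio_for_interim_ms``
+        of audio has been buffered, and anything under roughly 100 ms is never sent.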
+ + Args: + is_final: Whether this is the final transcription + + Returns: + Transcribed text or None if not enough audio + """ + if not self._session: + logger.warning("ASR session not connected") + return None + + # Check minimum audio duration + audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000 + + if not is_final and audio_duration_ms < self.min_audio_for_interim_ms: + return None + + if audio_duration_ms < 100: # Less than 100ms - too short + return None + + try: + # Convert PCM to WAV in memory + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # 16-bit + wav_file.setframerate(self.sample_rate) + wav_file.writeframes(self._audio_buffer) + + wav_buffer.seek(0) + wav_data = wav_buffer.read() + + # Send to API + form_data = aiohttp.FormData() + form_data.add_field( + 'file', + wav_data, + filename='audio.wav', + content_type='audio/wav' + ) + form_data.add_field('model', self.model) + + async with self._session.post(self.API_URL, data=form_data) as response: + if response.status == 200: + result = await response.json() + text = result.get("text", "").strip() + + if text: + self._current_text = text + + # Notify via callback + if self.on_transcript: + await self.on_transcript(text, is_final) + + # Queue result + await self._transcript_queue.put( + ASRResult(text=text, is_final=is_final) + ) + + logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...") + return text + else: + error_text = await response.text() + logger.error(f"ASR API error {response.status}: {error_text}") + return None + + except Exception as e: + logger.error(f"ASR transcription error: {e}") + return None + + async def get_final_transcription(self) -> str: + """ + Get final transcription and clear buffer. + + Call this when EOU is detected. + + Returns: + Final transcribed text + """ + # Transcribe full buffer as final + text = await self.transcribe_buffer(is_final=True) + + # Clear buffer + result = text or self._current_text + self._audio_buffer = b"" + self._current_text = "" + + return result + + def get_and_clear_text(self) -> str: + """ + Get accumulated text and clear buffer. + + Compatible with BufferedASRService interface. + """ + text = self._current_text + self._current_text = "" + self._audio_buffer = b"" + return text + + def get_audio_buffer(self) -> bytes: + """Get current audio buffer.""" + return self._audio_buffer + + def get_audio_duration_ms(self) -> float: + """Get current audio buffer duration in milliseconds.""" + return len(self._audio_buffer) / (self.sample_rate * 2) * 1000 + + def clear_buffer(self) -> None: + """Clear audio and text buffers.""" + self._audio_buffer = b"" + self._current_text = "" + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """ + Async iterator for transcription results. + + Yields: + ASRResult with text and is_final flag + """ + while self._running: + try: + result = await asyncio.wait_for( + self._transcript_queue.get(), + timeout=0.1 + ) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + async def start_interim_transcription(self) -> None: + """ + Start background task for interim transcriptions. + + This periodically transcribes buffered audio for + real-time feedback to the user. 
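+
+        Illustrative call pattern from the owning pipeline (the surrounding
+        wiring is an assumption, not code in this module)::
+
+            await asr.connect()
+            await asr.start_interim_transcription()
+            await asr.send_audio(pcm_chunk)              # per incoming audio frame
+            text = await asr.get_final_transcription()   # when end of utterance is detected
+            await asr.stop_interim_transcription()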
+ """ + if self._interim_task and not self._interim_task.done(): + return + + self._interim_task = asyncio.create_task(self._interim_loop()) + + async def stop_interim_transcription(self) -> None: + """Stop interim transcription task.""" + if self._interim_task: + self._interim_task.cancel() + try: + await self._interim_task + except asyncio.CancelledError: + pass + self._interim_task = None + + async def _interim_loop(self) -> None: + """Background loop for interim transcriptions.""" + import time + + while self._running: + try: + await asyncio.sleep(self.interim_interval_ms / 1000) + + # Check if we have enough new audio + current_time = time.time() + time_since_last = (current_time - self._last_interim_time) * 1000 + + if time_since_last >= self.interim_interval_ms: + audio_duration = self.get_audio_duration_ms() + + if audio_duration >= self.min_audio_for_interim_ms: + await self.transcribe_buffer(is_final=False) + self._last_interim_time = current_time + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Interim transcription error: {e}") diff --git a/engine/services/siliconflow_tts.py b/engine/services/siliconflow_tts.py new file mode 100644 index 0000000..38987f0 --- /dev/null +++ b/engine/services/siliconflow_tts.py @@ -0,0 +1,255 @@ +"""SiliconFlow TTS Service with streaming support. + +Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency +text-to-speech synthesis with streaming. + +API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech +""" + +import os +import asyncio +import aiohttp +from typing import AsyncIterator, Optional +from loguru import logger + +from services.base import BaseTTSService, TTSChunk, ServiceState + + +class SiliconFlowTTSService(BaseTTSService): + """ + SiliconFlow TTS service with streaming support. + + Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models. + """ + + # Available voices + VOICES = { + "alex": "FunAudioLLM/CosyVoice2-0.5B:alex", + "anna": "FunAudioLLM/CosyVoice2-0.5B:anna", + "bella": "FunAudioLLM/CosyVoice2-0.5B:bella", + "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin", + "charles": "FunAudioLLM/CosyVoice2-0.5B:charles", + "claire": "FunAudioLLM/CosyVoice2-0.5B:claire", + "david": "FunAudioLLM/CosyVoice2-0.5B:david", + "diana": "FunAudioLLM/CosyVoice2-0.5B:diana", + } + + def __init__( + self, + api_key: Optional[str] = None, + voice: str = "anna", + model: str = "FunAudioLLM/CosyVoice2-0.5B", + sample_rate: int = 16000, + speed: float = 1.0 + ): + """ + Initialize SiliconFlow TTS service. + + Args: + api_key: SiliconFlow API key (defaults to SILICONFLOW_API_KEY env var) + voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana) + model: Model name + sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100) + speed: Speech speed (0.25 to 4.0) + """ + # Resolve voice name + if voice in self.VOICES: + full_voice = self.VOICES[voice] + else: + full_voice = voice + + super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed) + + self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY") + self.model = model + self.api_url = "https://api.siliconflow.cn/v1/audio/speech" + + self._session: Optional[aiohttp.ClientSession] = None + self._cancel_event = asyncio.Event() + + async def connect(self) -> None: + """Initialize HTTP session.""" + if not self.api_key: + raise ValueError("SiliconFlow API key not provided. 
Set SILICONFLOW_API_KEY env var.") + + self._session = aiohttp.ClientSession( + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + ) + self.state = ServiceState.CONNECTED + logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}") + + async def disconnect(self) -> None: + """Close HTTP session.""" + if self._session: + await self._session.close() + self._session = None + self.state = ServiceState.DISCONNECTED + logger.info("SiliconFlow TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + """Synthesize complete audio for text.""" + audio_data = b"" + async for chunk in self.synthesize_stream(text): + audio_data += chunk.audio + return audio_data + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """ + Synthesize audio in streaming mode. + + Args: + text: Text to synthesize + + Yields: + TTSChunk objects with PCM audio + """ + if not self._session: + raise RuntimeError("TTS service not connected") + + if not text.strip(): + return + + self._cancel_event.clear() + + payload = { + "model": self.model, + "input": text, + "voice": self.voice, + "response_format": "pcm", + "sample_rate": self.sample_rate, + "stream": True, + "speed": self.speed + } + + try: + async with self._session.post(self.api_url, json=payload) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}") + return + + # Stream audio chunks + chunk_size = self.sample_rate * 2 // 10 # 100ms chunks + buffer = b"" + + async for chunk in response.content.iter_any(): + if self._cancel_event.is_set(): + logger.info("TTS synthesis cancelled") + return + + buffer += chunk + + # Yield complete chunks + while len(buffer) >= chunk_size: + audio_chunk = buffer[:chunk_size] + buffer = buffer[chunk_size:] + + yield TTSChunk( + audio=audio_chunk, + sample_rate=self.sample_rate, + is_final=False + ) + + # Yield remaining buffer + if buffer: + yield TTSChunk( + audio=buffer, + sample_rate=self.sample_rate, + is_final=True + ) + + except asyncio.CancelledError: + logger.info("TTS synthesis cancelled via asyncio") + raise + except Exception as e: + logger.error(f"TTS synthesis error: {e}") + raise + + async def cancel(self) -> None: + """Cancel ongoing synthesis.""" + self._cancel_event.set() + + +class StreamingTTSAdapter: + """ + Adapter for streaming LLM text to TTS with sentence-level chunking. + + This reduces latency by starting TTS as soon as a complete sentence + is received from the LLM, rather than waiting for the full response. + """ + + # Sentence delimiters + SENTENCE_ENDS = {'.', '!', '?', '。', '!', '?', ';', '\n'} + + def __init__(self, tts_service: BaseTTSService, transport, session_id: str): + self.tts_service = tts_service + self.transport = transport + self.session_id = session_id + self._buffer = "" + self._cancel_event = asyncio.Event() + self._is_speaking = False + + async def process_text_chunk(self, text_chunk: str) -> None: + """ + Process a text chunk from LLM and trigger TTS when sentence is complete. 
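+
+        A minimal sketch of the intended wiring, assuming ``llm`` exposes a
+        ``generate_stream`` iterator like the LLM services in this package::
+
+            adapter = StreamingTTSAdapter(tts_service, transport, session_id)
+            async for delta in llm.generate_stream(messages):
+                await adapter.process_text_chunk(delta)
+            await adapter.flush()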
+ + Args: + text_chunk: Text chunk from LLM streaming + """ + if self._cancel_event.is_set(): + return + + self._buffer += text_chunk + + # Check for sentence completion + for i, char in enumerate(self._buffer): + if char in self.SENTENCE_ENDS: + # Found sentence end, synthesize up to this point + sentence = self._buffer[:i+1].strip() + self._buffer = self._buffer[i+1:] + + if sentence: + await self._speak_sentence(sentence) + break + + async def flush(self) -> None: + """Flush remaining buffer.""" + if self._buffer.strip() and not self._cancel_event.is_set(): + await self._speak_sentence(self._buffer.strip()) + self._buffer = "" + + async def _speak_sentence(self, text: str) -> None: + """Synthesize and send a sentence.""" + if not text or self._cancel_event.is_set(): + return + + self._is_speaking = True + + try: + async for chunk in self.tts_service.synthesize_stream(text): + if self._cancel_event.is_set(): + break + await self.transport.send_audio(chunk.audio) + await asyncio.sleep(0.01) # Prevent flooding + except Exception as e: + logger.error(f"TTS speak error: {e}") + finally: + self._is_speaking = False + + def cancel(self) -> None: + """Cancel ongoing speech.""" + self._cancel_event.set() + self._buffer = "" + + def reset(self) -> None: + """Reset for new turn.""" + self._cancel_event.clear() + self._buffer = "" + self._is_speaking = False + + @property + def is_speaking(self) -> bool: + return self._is_speaking diff --git a/engine/services/tts.py b/engine/services/tts.py new file mode 100644 index 0000000..e838f08 --- /dev/null +++ b/engine/services/tts.py @@ -0,0 +1,271 @@ +"""TTS (Text-to-Speech) Service implementations. + +Provides multiple TTS backend options including edge-tts (free) +and placeholder for cloud services. +""" + +import os +import io +import asyncio +import struct +from typing import AsyncIterator, Optional +from loguru import logger + +from services.base import BaseTTSService, TTSChunk, ServiceState + +# Try to import edge-tts +try: + import edge_tts + EDGE_TTS_AVAILABLE = True +except ImportError: + EDGE_TTS_AVAILABLE = False + logger.warning("edge-tts not available - EdgeTTS service will be disabled") + + +class EdgeTTSService(BaseTTSService): + """ + Microsoft Edge TTS service. + + Uses edge-tts library for free, high-quality speech synthesis. + Supports streaming for low-latency playback. + """ + + # Voice mapping for common languages + VOICE_MAP = { + "en": "en-US-JennyNeural", + "en-US": "en-US-JennyNeural", + "en-GB": "en-GB-SoniaNeural", + "zh": "zh-CN-XiaoxiaoNeural", + "zh-CN": "zh-CN-XiaoxiaoNeural", + "zh-TW": "zh-TW-HsiaoChenNeural", + "ja": "ja-JP-NanamiNeural", + "ko": "ko-KR-SunHiNeural", + "fr": "fr-FR-DeniseNeural", + "de": "de-DE-KatjaNeural", + "es": "es-ES-ElviraNeural", + } + + def __init__( + self, + voice: str = "en-US-JennyNeural", + sample_rate: int = 16000, + speed: float = 1.0 + ): + """ + Initialize Edge TTS service. 
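+
+        Bare language codes are resolved through ``VOICE_MAP``, e.g.
+        ``EdgeTTSService(voice="zh")`` selects ``zh-CN-XiaoxiaoNeural``.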
+ + Args: + voice: Voice name (e.g., "en-US-JennyNeural") or language code (e.g., "en") + sample_rate: Target sample rate (will be resampled) + speed: Speech speed multiplier + """ + # Resolve voice from language code if needed + if voice in self.VOICE_MAP: + voice = self.VOICE_MAP[voice] + + super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) + self._cancel_event = asyncio.Event() + + async def connect(self) -> None: + """Edge TTS doesn't require explicit connection.""" + if not EDGE_TTS_AVAILABLE: + raise RuntimeError("edge-tts package not installed") + self.state = ServiceState.CONNECTED + logger.info(f"Edge TTS service ready: voice={self.voice}") + + async def disconnect(self) -> None: + """Edge TTS doesn't require explicit disconnection.""" + self.state = ServiceState.DISCONNECTED + logger.info("Edge TTS service disconnected") + + def _get_rate_string(self) -> str: + """Convert speed to rate string for edge-tts.""" + # edge-tts uses percentage format: "+0%", "-10%", "+20%" + percentage = int((self.speed - 1.0) * 100) + if percentage >= 0: + return f"+{percentage}%" + return f"{percentage}%" + + async def synthesize(self, text: str) -> bytes: + """ + Synthesize complete audio for text. + + Args: + text: Text to synthesize + + Returns: + PCM audio data (16-bit, mono, 16kHz) + """ + if not EDGE_TTS_AVAILABLE: + raise RuntimeError("edge-tts not available") + + # Collect all chunks + audio_data = b"" + async for chunk in self.synthesize_stream(text): + audio_data += chunk.audio + + return audio_data + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """ + Synthesize audio in streaming mode. + + Args: + text: Text to synthesize + + Yields: + TTSChunk objects with PCM audio + """ + if not EDGE_TTS_AVAILABLE: + raise RuntimeError("edge-tts not available") + + self._cancel_event.clear() + + try: + communicate = edge_tts.Communicate( + text, + voice=self.voice, + rate=self._get_rate_string() + ) + + # edge-tts outputs MP3, we need to decode to PCM + # For now, collect MP3 chunks and yield after conversion + mp3_data = b"" + + async for chunk in communicate.stream(): + # Check for cancellation + if self._cancel_event.is_set(): + logger.info("TTS synthesis cancelled") + return + + if chunk["type"] == "audio": + mp3_data += chunk["data"] + + # Convert MP3 to PCM + if mp3_data: + pcm_data = await self._convert_mp3_to_pcm(mp3_data) + if pcm_data: + # Yield in chunks for streaming playback + chunk_size = self.sample_rate * 2 // 10 # 100ms chunks + for i in range(0, len(pcm_data), chunk_size): + if self._cancel_event.is_set(): + return + + chunk_data = pcm_data[i:i + chunk_size] + yield TTSChunk( + audio=chunk_data, + sample_rate=self.sample_rate, + is_final=(i + chunk_size >= len(pcm_data)) + ) + + except asyncio.CancelledError: + logger.info("TTS synthesis cancelled via asyncio") + raise + except Exception as e: + logger.error(f"TTS synthesis error: {e}") + raise + + async def _convert_mp3_to_pcm(self, mp3_data: bytes) -> bytes: + """ + Convert MP3 audio to PCM. + + Uses pydub or ffmpeg for conversion. 
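+
+        Prefers pydub (which itself requires ffmpeg) and falls back to a direct
+        ffmpeg subprocess; returns 16-bit mono PCM at ``self.sample_rate``, or
+        empty bytes if conversion fails.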
+ """ + try: + # Try using pydub (requires ffmpeg) + from pydub import AudioSegment + + # Load MP3 from bytes + audio = AudioSegment.from_mp3(io.BytesIO(mp3_data)) + + # Convert to target format + audio = audio.set_frame_rate(self.sample_rate) + audio = audio.set_channels(1) + audio = audio.set_sample_width(2) # 16-bit + + # Export as raw PCM + return audio.raw_data + + except ImportError: + logger.warning("pydub not available, trying fallback") + # Fallback: Use subprocess to call ffmpeg directly + return await self._ffmpeg_convert(mp3_data) + except Exception as e: + logger.error(f"Audio conversion error: {e}") + return b"" + + async def _ffmpeg_convert(self, mp3_data: bytes) -> bytes: + """Convert MP3 to PCM using ffmpeg subprocess.""" + try: + process = await asyncio.create_subprocess_exec( + "ffmpeg", + "-i", "pipe:0", + "-f", "s16le", + "-acodec", "pcm_s16le", + "-ar", str(self.sample_rate), + "-ac", "1", + "pipe:1", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL + ) + + stdout, _ = await process.communicate(input=mp3_data) + return stdout + + except Exception as e: + logger.error(f"ffmpeg conversion error: {e}") + return b"" + + async def cancel(self) -> None: + """Cancel ongoing synthesis.""" + self._cancel_event.set() + + +class MockTTSService(BaseTTSService): + """ + Mock TTS service for testing without actual synthesis. + + Generates silence or simple tones. + """ + + def __init__( + self, + voice: str = "mock", + sample_rate: int = 16000, + speed: float = 1.0 + ): + super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) + + async def connect(self) -> None: + self.state = ServiceState.CONNECTED + logger.info("Mock TTS service connected") + + async def disconnect(self) -> None: + self.state = ServiceState.DISCONNECTED + logger.info("Mock TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + """Generate silence based on text length.""" + # Approximate: 100ms per word + word_count = len(text.split()) + duration_ms = word_count * 100 + samples = int(self.sample_rate * duration_ms / 1000) + + # Generate silence (zeros) + return bytes(samples * 2) # 16-bit = 2 bytes per sample + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """Generate silence chunks.""" + audio = await self.synthesize(text) + + # Yield in 100ms chunks + chunk_size = self.sample_rate * 2 // 10 + for i in range(0, len(audio), chunk_size): + chunk_data = audio[i:i + chunk_size] + yield TTSChunk( + audio=chunk_data, + sample_rate=self.sample_rate, + is_final=(i + chunk_size >= len(audio)) + ) + await asyncio.sleep(0.05) # Simulate processing time diff --git a/engine/utils/__init__.py b/engine/utils/__init__.py new file mode 100644 index 0000000..48a989f --- /dev/null +++ b/engine/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities Package""" diff --git a/engine/utils/logging.py b/engine/utils/logging.py new file mode 100644 index 0000000..28b3a8f --- /dev/null +++ b/engine/utils/logging.py @@ -0,0 +1,83 @@ +"""Logging configuration utilities.""" + +import sys +from loguru import logger +from pathlib import Path + + +def setup_logging( + log_level: str = "INFO", + log_format: str = "text", + log_to_file: bool = True, + log_dir: str = "logs" +): + """ + Configure structured logging with loguru. 
+ + Args: + log_level: Logging level (DEBUG, INFO, WARNING, ERROR) + log_format: Format type (json or text) + log_to_file: Whether to log to file + log_dir: Directory for log files + """ + # Remove default handler + logger.remove() + + # Console handler + if log_format == "json": + logger.add( + sys.stdout, + format="{message}", + level=log_level, + serialize=True, + colorize=False + ) + else: + logger.add( + sys.stdout, + format="{time:HH:mm:ss} | {level: <8} | {message}", + level=log_level, + colorize=True + ) + + # File handler + if log_to_file: + log_path = Path(log_dir) + log_path.mkdir(exist_ok=True) + + if log_format == "json": + logger.add( + log_path / "active_call_{time:YYYY-MM-DD}.log", + format="{message}", + level=log_level, + rotation="1 day", + retention="7 days", + compression="zip", + serialize=True + ) + else: + logger.add( + log_path / "active_call_{time:YYYY-MM-DD}.log", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + level=log_level, + rotation="1 day", + retention="7 days", + compression="zip" + ) + + return logger + + +def get_logger(name: str = None): + """ + Get a logger instance. + + Args: + name: Logger name (optional) + + Returns: + Logger instance + """ + if name: + return logger.bind(name=name) + return logger