Add support for Xfyun ASR and TTS services in the backend
- Introduce new Xfyun ASR and TTS services, enabling integration with iFlytek's voice recognition and synthesis capabilities. - Update AssistantConfig model to include interface types for STT and TTS. - Enhance credential testing to validate Xfyun credentials. - Modify service factory to create Xfyun services based on configuration. - Update README with new configuration details for Xfyun integration. - Add new frontend components for visualizing audio streams and managing user interactions.
This commit is contained in:
@@ -9,6 +9,14 @@
|
||||
"port": 3001,
|
||||
"autoPort": false
|
||||
},
|
||||
{
|
||||
"name": "ai-video-admin-3000",
|
||||
"runtimeExecutable": "npm",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"cwd": "frontend",
|
||||
"port": 3000,
|
||||
"autoPort": false
|
||||
},
|
||||
{
|
||||
"name": "ui-docker",
|
||||
"runtimeExecutable": "docker",
|
||||
|
||||
@@ -55,6 +55,18 @@ ai-video-backend/
|
||||
| STT | SenseVoice / FunASR | 本地 OpenAI 兼容转写服务 |
|
||||
| TTS | CosyVoice | 本地 OpenAI 兼容 TTS 服务 |
|
||||
|
||||
### 讯飞 ASR / TTS / SuperTTS
|
||||
|
||||
讯飞继续复用 `ProviderCredential` 的现有字段,不增加专属列:
|
||||
|
||||
- `interface_type`: `xfyun`
|
||||
- `api_url`: 讯飞 WebSocket URL(`https://` 会自动转为 `wss://`)
|
||||
- `api_key`: `{"appId":"...","apiKey":"...","apiSecret":"..."}`
|
||||
- ASR `model_id`: `iat`
|
||||
- 普通 TTS `model_id`: `tts`
|
||||
- 超拟人 TTS `model_id`: `supertts`(包含 `/private/` 的 URL 也会自动识别)
|
||||
- TTS `voice`: 讯飞音色 ID;`speed=1.0` 对应讯飞正常语速 `50`
|
||||
|
||||
## 本地运行(用 uv,Python 3.12)
|
||||
|
||||
```bash
|
||||
@@ -97,7 +109,7 @@ docker compose --profile remote up -d
|
||||
|
||||
## 待联调 / TODO
|
||||
|
||||
- [ ] `pip install` 后跑通,核对 pipecat 版本的服务/transport 构造参数(代码内有注释)
|
||||
- [ ] 联调 Pipecat 1.3.0 语音链路与各 OpenAI 兼容服务
|
||||
- [ ] 起本地 SenseVoice / CosyVoice 的 OpenAI 兼容服务
|
||||
- [ ] `realtime` 模式(目前只 `pipeline` 级联)
|
||||
- [x] 前端 `DebugVoicePanel` 接 `/ws/voice`(参考 dograh `useWebSocketRTC.tsx`)
|
||||
|
||||
@@ -17,9 +17,9 @@ VALUES
|
||||
('model_003', 'SiliconFlow-Qwen3-Embedding-4B', 'Qwen/Qwen3-Embedding-4B', 'Embedding', 'openai', 'https://api.siliconflow.cn/v1', 'sk-uudpgflahqqjbofhgcbwjjefgwhvwwmxgeyehcueqlemwavq', '', 1.0, '', TRUE),
|
||||
('model_004', 'SiliconFlow-CosyVoice2-0.5B', 'FunAudioLLM/CosyVoice2-0.5B', 'TTS', 'openai', 'https://api.siliconflow.cn/v1', 'sk-uudpgflahqqjbofhgcbwjjefgwhvwwmxgeyehcueqlemwavq', 'FunAudioLLM/CosyVoice2-0.5B:anna', 1.0, '', FALSE),
|
||||
('model_005', 'Qwen-Max', 'qwen-max', 'LLM', 'openai', 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'sk-qwen-4d8e2a6f0c', '', 1.0, '', FALSE),
|
||||
('model_006', '讯飞语音识别', 'iat', 'ASR', 'xfyun', 'https://iat-api.xfyun.cn/v2/iat', 'xf-asr-9b1c3d5e7a', '', 1.0, 'zh', TRUE),
|
||||
('model_006', '讯飞语音识别', 'iat', 'ASR', 'xfyun', 'https://iat-api.xfyun.cn/v2/iat', '{"appId":"replace-me","apiKey":"replace-me","apiSecret":"replace-me"}', '', 1.0, 'zh', TRUE),
|
||||
('model_007', 'Paraformer 识别', 'paraformer-realtime-v2', 'ASR', 'dashscope', 'https://dashscope.aliyuncs.com/api/v1/services/audio/asr', 'sk-paraformer-2e4f6a', '', 1.0, 'zh', FALSE),
|
||||
('model_008', '讯飞语音合成', 'tts', 'TTS', 'xfyun', 'https://tts-api.xfyun.cn/v2/tts', 'xf-tts-6c8a0b2d4f', 'xiaoyan', 1.0, '', TRUE),
|
||||
('model_008', '讯飞语音合成', 'tts', 'TTS', 'xfyun', 'https://tts-api.xfyun.cn/v2/tts', '{"appId":"replace-me","apiKey":"replace-me","apiSecret":"replace-me"}', 'xiaoyan', 1.0, '', TRUE),
|
||||
('model_009', 'CosyVoice 合成', 'cosyvoice-v1', 'TTS', 'dashscope', 'https://dashscope.aliyuncs.com/api/v1/services/audio/tts', 'sk-cosyvoice-1a3c5e', 'longxiaochun', 1.0, '', FALSE),
|
||||
('model_010', 'GPT Realtime', 'gpt-4o-realtime-preview', 'Realtime', 'openai', 'https://api.openai.com/v1/realtime', 'sk-realtime-3b5d7f9a1c', '', 1.0, '', TRUE),
|
||||
('model_011', 'Gemini Live', 'gemini-2.0-flash-live', 'Realtime', 'gemini', 'https://generativelanguage.googleapis.com/v1beta', 'gm-live-5e7a9c1b3d', '', 1.0, '', FALSE),
|
||||
|
||||
@@ -31,6 +31,8 @@ class AssistantConfig(BaseModel):
|
||||
stt_language: str = ""
|
||||
tts_speed: float = 1.0
|
||||
realtimeModel: str = ""
|
||||
stt_interface_type: str = "openai"
|
||||
tts_interface_type: str = "openai"
|
||||
|
||||
enableInterrupt: bool = True
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# webrtc -> SmallWebRTCTransport / SmallWebRTCConnection + aiortc
|
||||
# silero -> 本地 VAD(判断用户说话起止),语音必备
|
||||
# openai -> OpenAI 兼容的 LLM/STT/TTS 客户端(DeepSeek、SenseVoice、CosyVoice 都走它)
|
||||
pipecat-ai[webrtc,silero,openai]~=0.0.60
|
||||
pipecat-ai[webrtc,websocket,silero,openai]==1.3.0
|
||||
|
||||
fastapi
|
||||
httpx
|
||||
@@ -10,6 +10,7 @@ uvicorn[standard]
|
||||
python-dotenv
|
||||
pydantic
|
||||
loguru
|
||||
websockets>=13
|
||||
|
||||
# 存储:Postgres(SQLAlchemy 2.0 异步 + asyncpg 驱动)
|
||||
sqlalchemy[asyncio]>=2.0
|
||||
|
||||
@@ -15,7 +15,7 @@ from schemas import (
|
||||
CredentialTestResult,
|
||||
CredentialUpsert,
|
||||
)
|
||||
from services.credential_tester import test_openai_credential
|
||||
from services.credential_tester import test_openai_credential, test_xfyun_credential
|
||||
from services.masking import mask, resolve_incoming_key
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -84,6 +84,8 @@ async def create_credential(
|
||||
|
||||
@router.post("/test", response_model=CredentialTestResult)
|
||||
async def test_new_credential(body: CredentialTestRequest):
|
||||
if body.interface_type == "xfyun":
|
||||
return test_xfyun_credential(body)
|
||||
if body.interface_type != "openai":
|
||||
return CredentialTestResult(
|
||||
ok=False,
|
||||
@@ -111,6 +113,8 @@ async def test_saved_credential(
|
||||
config = body.model_copy(
|
||||
update={"api_key": resolve_incoming_key(body.api_key, c.api_key)}
|
||||
)
|
||||
if config.interface_type == "xfyun":
|
||||
return test_xfyun_credential(config)
|
||||
if config.interface_type != "openai":
|
||||
return CredentialTestResult(
|
||||
ok=False,
|
||||
|
||||
@@ -62,6 +62,8 @@ async def resolve_runtime_config(
|
||||
voice=(tts.voice if tts else ""),
|
||||
stt_language=(stt.language if stt else ""),
|
||||
tts_speed=(tts.speed if tts else 1.0),
|
||||
stt_interface_type=(stt.interface_type if stt else "openai"),
|
||||
tts_interface_type=(tts.interface_type if tts else "openai"),
|
||||
realtimeModel=(realtime.model_id if realtime else ""),
|
||||
# 运行时连接信息(真 key + url):凭证优先,否则 .env 兜底
|
||||
llm_api_key=(llm.api_key if llm else config.LLM_API_KEY),
|
||||
|
||||
@@ -9,6 +9,7 @@ import wave
|
||||
import httpx
|
||||
|
||||
from schemas import CredentialTestRequest, CredentialTestResult
|
||||
from services.pipecat.xfyun_config import parse_xfyun_credential
|
||||
|
||||
TEST_TIMEOUT_SECONDS = 10.0
|
||||
|
||||
@@ -123,3 +124,25 @@ async def test_openai_credential(
|
||||
message="无法连接到模型服务",
|
||||
detail=str(exc)[:300],
|
||||
)
|
||||
|
||||
|
||||
def test_xfyun_credential(config: CredentialTestRequest) -> CredentialTestResult:
|
||||
"""Validate the Xfyun credential packed into the existing api_key field.
|
||||
|
||||
Actual signed-WebSocket synthesis/recognition is exercised by the voice
|
||||
pipeline; this check deliberately avoids consuming provider quota.
|
||||
"""
|
||||
try:
|
||||
parse_xfyun_credential(config.api_key)
|
||||
except ValueError as exc:
|
||||
return CredentialTestResult(
|
||||
ok=False,
|
||||
message="讯飞凭证格式无效",
|
||||
detail=str(exc),
|
||||
)
|
||||
|
||||
return CredentialTestResult(
|
||||
ok=True,
|
||||
message="讯飞凭证格式有效",
|
||||
detail="请在语音测试页验证签名、识别和合成链路",
|
||||
)
|
||||
|
||||
@@ -10,19 +10,84 @@ from loguru import logger
|
||||
from models import AssistantConfig
|
||||
from services.pipecat.service_factory import create_services
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
InterruptionTaskFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
InputTextRawFrame,
|
||||
InputTransportMessageFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.turns.user_start import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
VADUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.workers.runner import WorkerRunner
|
||||
|
||||
|
||||
def _text_input(message) -> tuple[str, bool] | None:
|
||||
"""解析现有 user-text 与 RTVI send-text 两种前端文字消息。"""
|
||||
if not isinstance(message, dict):
|
||||
return None
|
||||
if message.get("type") == "user-text":
|
||||
text = str(message.get("text") or "").strip()
|
||||
return (text, True) if text else None
|
||||
if message.get("type") == "send-text":
|
||||
data = message.get("data")
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
text = str(data.get("content") or "").strip()
|
||||
options = data.get("options")
|
||||
run_immediately = not isinstance(options, dict) or options.get(
|
||||
"run_immediately", True
|
||||
)
|
||||
return (text, bool(run_immediately)) if text else None
|
||||
return None
|
||||
|
||||
|
||||
class TextInputProcessor(FrameProcessor):
|
||||
"""把 transport 文字消息转换成级联与实时 LLM 都能消费的帧。"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._register_event_handler("on_text_input")
|
||||
|
||||
async def process_frame(self, frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if not isinstance(frame, InputTransportMessageFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
parsed = _text_input(frame.message)
|
||||
if not parsed:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
text, run_immediately = parsed
|
||||
if run_immediately:
|
||||
await self.broadcast_interruption()
|
||||
|
||||
await self.push_frame(
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
run_llm=run_immediately,
|
||||
)
|
||||
)
|
||||
if run_immediately:
|
||||
await self.push_frame(InputTextRawFrame(text=text))
|
||||
await self._call_event_handler("on_text_input", text)
|
||||
|
||||
|
||||
async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
@@ -37,78 +102,80 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
|
||||
stt, llm, tts = create_services(cfg)
|
||||
|
||||
context = OpenAILLMContext(messages=[{"role": "system", "content": cfg.prompt}])
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# 转写收集:user 侧收 ASR 最终转写,assistant 侧聚合 TTS 实际播报的文本,
|
||||
# 统一通过 data channel 推给前端聊天记录面板。
|
||||
transcript = TranscriptProcessor()
|
||||
context = LLMContext(messages=[{"role": "system", "content": cfg.prompt}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(enable_interruptions=cfg.enableInterrupt),
|
||||
TranscriptionUserTurnStartStrategy(
|
||||
enable_interruptions=cfg.enableInterrupt
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
text_input = TextInputProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
text_input,
|
||||
stt,
|
||||
transcript.user(),
|
||||
context_aggregator.user(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
transcript.assistant(),
|
||||
context_aggregator.assistant(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
worker = PipelineWorker(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=cfg.enableInterrupt,
|
||||
enable_metrics=False,
|
||||
),
|
||||
enable_rtvi=False,
|
||||
)
|
||||
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(_processor, frame):
|
||||
# 每条最终转写(用户/助手)推给前端,前端据此渲染聊天记录
|
||||
for msg in frame.messages:
|
||||
await task.queue_frame(
|
||||
TransportMessageUrgentFrame(
|
||||
async def queue_transcript(role: str, content: str, timestamp: str) -> None:
|
||||
if content:
|
||||
await worker.queue_frame(
|
||||
OutputTransportMessageUrgentFrame(
|
||||
message={
|
||||
"type": "transcript",
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"timestamp": msg.timestamp,
|
||||
}
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@transport.event_handler("on_app_message")
|
||||
async def on_app_message(_transport, message, _sender):
|
||||
# 前端文字输入:先打断当前播报,再当作一条用户最终转写注入,
|
||||
# 走与语音完全相同的 转写→上下文→LLM→TTS 链路
|
||||
if not isinstance(message, dict) or message.get("type") != "user-text":
|
||||
return
|
||||
text = str(message.get("text") or "").strip()
|
||||
if not text:
|
||||
return
|
||||
await task.queue_frames(
|
||||
[
|
||||
InterruptionTaskFrame(),
|
||||
TranscriptionFrame(
|
||||
text=text, user_id="debug", timestamp=time_now_iso8601()
|
||||
),
|
||||
]
|
||||
)
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(_aggregator, _strategy, message):
|
||||
await queue_transcript("user", message.content, message.timestamp)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
|
||||
async def on_assistant_turn_stopped(_aggregator, message):
|
||||
await queue_transcript("assistant", message.content, message.timestamp)
|
||||
|
||||
@text_input.event_handler("on_text_input")
|
||||
async def on_text_input(_processor, text):
|
||||
await queue_transcript("user", text, time_now_iso8601())
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(_transport, _client):
|
||||
if cfg.greeting:
|
||||
await task.queue_frame(TTSSpeakFrame(cfg.greeting))
|
||||
await worker.queue_frame(TTSSpeakFrame(cfg.greeting))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(_transport, _client):
|
||||
logger.info("对端断开,结束管线")
|
||||
await task.queue_frame(EndFrame())
|
||||
await worker.queue_frame(EndFrame())
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
await runner.run(task)
|
||||
runner = WorkerRunner(handle_sigint=False)
|
||||
await runner.add_workers(worker)
|
||||
await runner.run()
|
||||
logger.info("管线已结束")
|
||||
|
||||
@@ -13,6 +13,20 @@ from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from services.pipecat.xfyun_asr import DEFAULT_XFYUN_ASR_URL, XfyunASRService
|
||||
from services.pipecat.xfyun_config import (
|
||||
is_super_tts,
|
||||
parse_xfyun_credential,
|
||||
websocket_url,
|
||||
xfyun_language,
|
||||
xfyun_speed,
|
||||
)
|
||||
from services.pipecat.xfyun_super_tts import (
|
||||
DEFAULT_XFYUN_SUPER_TTS_URL,
|
||||
XfyunSuperTTSService,
|
||||
)
|
||||
from services.pipecat.xfyun_tts import DEFAULT_XFYUN_TTS_URL, XfyunTTSService
|
||||
|
||||
|
||||
def _language(value: str) -> Language | None:
|
||||
if not value:
|
||||
@@ -29,11 +43,24 @@ def create_stt(cfg: AssistantConfig):
|
||||
|
||||
连接信息优先用 cfg(由 config_resolver 从 DB 注入),为空回退 .env 默认。
|
||||
"""
|
||||
if cfg.stt_interface_type == "xfyun":
|
||||
credential = parse_xfyun_credential(cfg.stt_api_key)
|
||||
return XfyunASRService(
|
||||
app_id=credential.app_id,
|
||||
api_key=credential.api_key,
|
||||
api_secret=credential.api_secret,
|
||||
url=websocket_url(cfg.stt_base_url, DEFAULT_XFYUN_ASR_URL),
|
||||
language=xfyun_language(cfg.stt_language),
|
||||
sample_rate=16000,
|
||||
)
|
||||
|
||||
return OpenAISTTService(
|
||||
api_key=cfg.stt_api_key or config.STT_API_KEY,
|
||||
base_url=cfg.stt_base_url or config.STT_BASE_URL,
|
||||
model=cfg.asr or config.STT_MODEL,
|
||||
language=_language(cfg.stt_language),
|
||||
settings=OpenAISTTService.Settings(
|
||||
model=cfg.asr or config.STT_MODEL,
|
||||
language=_language(cfg.stt_language),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -42,13 +69,39 @@ def create_llm(cfg: AssistantConfig):
|
||||
return OpenAILLMService(
|
||||
api_key=cfg.llm_api_key or config.LLM_API_KEY,
|
||||
base_url=cfg.llm_base_url or config.LLM_BASE_URL,
|
||||
model=cfg.model or config.LLM_MODEL,
|
||||
settings=OpenAILLMService.Settings(model=cfg.model or config.LLM_MODEL),
|
||||
)
|
||||
|
||||
|
||||
def create_tts(cfg: AssistantConfig):
|
||||
"""CosyVoice 等,走 OpenAI 兼容的 /v1/audio/speech。"""
|
||||
voice = cfg.voice or config.TTS_VOICE
|
||||
if cfg.tts_interface_type == "xfyun":
|
||||
credential = parse_xfyun_credential(cfg.tts_api_key)
|
||||
speed = xfyun_speed(cfg.tts_speed)
|
||||
if is_super_tts(cfg.tts_model, cfg.tts_base_url):
|
||||
return XfyunSuperTTSService(
|
||||
app_id=credential.app_id,
|
||||
api_key=credential.api_key,
|
||||
api_secret=credential.api_secret,
|
||||
voice=voice,
|
||||
url=websocket_url(cfg.tts_base_url, DEFAULT_XFYUN_SUPER_TTS_URL),
|
||||
sample_rate=16000,
|
||||
source_sample_rate=24000,
|
||||
speed=speed,
|
||||
)
|
||||
return XfyunTTSService(
|
||||
app_id=credential.app_id,
|
||||
api_key=credential.api_key,
|
||||
api_secret=credential.api_secret,
|
||||
voice=voice,
|
||||
url=websocket_url(cfg.tts_base_url, DEFAULT_XFYUN_TTS_URL),
|
||||
sample_rate=16000,
|
||||
source_sample_rate=16000,
|
||||
speed=speed,
|
||||
push_stop_frames=True,
|
||||
)
|
||||
|
||||
# Pipecat 默认只接受 OpenAI 官方音色。OpenAI 兼容服务常使用自定义 voice id,
|
||||
# 注册为原样映射后仍由 OpenAI SDK 按字符串透传给供应商。
|
||||
VALID_VOICES.setdefault(voice, voice)
|
||||
@@ -65,9 +118,9 @@ def create_tts(cfg: AssistantConfig):
|
||||
|
||||
def create_services(cfg: AssistantConfig):
|
||||
logger.info(
|
||||
f"创建服务: stt={cfg.asr or config.STT_MODEL} "
|
||||
f"创建服务: stt={cfg.stt_interface_type}/{cfg.asr or config.STT_MODEL} "
|
||||
f"llm={cfg.model or config.LLM_MODEL} "
|
||||
f"tts={cfg.tts_model or config.TTS_MODEL} "
|
||||
f"tts={cfg.tts_interface_type}/{cfg.tts_model or config.TTS_MODEL} "
|
||||
f"voice={cfg.voice or config.TTS_VOICE}"
|
||||
)
|
||||
return create_stt(cfg), create_llm(cfg), create_tts(cfg)
|
||||
|
||||
@@ -11,14 +11,13 @@
|
||||
from fastapi import WebSocket
|
||||
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
|
||||
# WebRTC
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
# 裸 WS 音频流
|
||||
from pipecat.transports.network.fastapi_websocket import (
|
||||
from pipecat.transports.websocket.fastapi import (
|
||||
FastAPIWebsocketTransport,
|
||||
FastAPIWebsocketParams,
|
||||
)
|
||||
@@ -30,7 +29,6 @@ def _base_params() -> dict:
|
||||
return dict(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(), # 本地 VAD,打断功能依赖它
|
||||
)
|
||||
|
||||
|
||||
|
||||
353
backend/services/pipecat/xfyun_asr.py
Normal file
353
backend/services/pipecat/xfyun_asr.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import format_datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import STTSettings
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
|
||||
|
||||
DEFAULT_XFYUN_ASR_URL = "wss://iat-api.xfyun.cn/v2/iat"
|
||||
|
||||
|
||||
class XfyunASRService(STTService):
|
||||
"""iFlytek/Xfyun streaming voice dictation service for Pipecat."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app_id: str,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
url: str | None = None,
|
||||
language: str = "zh_cn",
|
||||
domain: str = "iat",
|
||||
accent: str = "mandarin",
|
||||
sample_rate: int = 16000,
|
||||
encoding: str = "raw",
|
||||
frame_size: int = 1280,
|
||||
open_timeout: float = 10.0,
|
||||
dynamic_correction: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
settings=STTSettings(model=None, language=language),
|
||||
**kwargs,
|
||||
)
|
||||
self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
|
||||
self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
|
||||
self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
|
||||
self._url = url or DEFAULT_XFYUN_ASR_URL
|
||||
self._language = language
|
||||
self._domain = domain
|
||||
self._accent = accent
|
||||
self._encoding = encoding
|
||||
self._frame_size = frame_size
|
||||
self._open_timeout = open_timeout
|
||||
self._dynamic_correction = dynamic_correction
|
||||
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._audio_buffer = bytearray()
|
||||
self._sent_first_frame = False
|
||||
self._sent_final_frame = False
|
||||
self._finalizing_turn = False
|
||||
self._partials: list[str] = []
|
||||
self._last_text = ""
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
await self._close_utterance()
|
||||
await super().cleanup()
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
|
||||
await self._close_utterance()
|
||||
await super().stop(frame)
|
||||
|
||||
async def cancel(self, frame: CancelFrame) -> None:
|
||||
await self._close_utterance()
|
||||
await super().cancel(frame)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Aggregator-level turn end (broadcast once per logical user turn).
|
||||
# This is the only boundary that finalizes/closes the xfyun
|
||||
# websocket, so brief VAD pauses do not restart the ASR session.
|
||||
await self._finish_utterance()
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._start_utterance()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
|
||||
if not audio:
|
||||
yield None
|
||||
return
|
||||
|
||||
if not self._websocket or self._websocket.state is not State.OPEN:
|
||||
await self._start_utterance()
|
||||
|
||||
self._audio_buffer.extend(audio)
|
||||
await self._flush_audio_buffer(final=False)
|
||||
yield None
|
||||
|
||||
async def _start_utterance(self) -> None:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
if not self._app_id or not self._api_key or not self._api_secret:
|
||||
await self.push_error("Xfyun ASR requires app_id, api_key, and api_secret")
|
||||
return
|
||||
|
||||
if self.sample_rate not in (8000, 16000):
|
||||
await self.push_error("Xfyun ASR sample rate must be 8000 or 16000")
|
||||
return
|
||||
|
||||
self._audio_buffer.clear()
|
||||
self._partials = []
|
||||
self._last_text = ""
|
||||
self._sent_first_frame = False
|
||||
self._sent_final_frame = False
|
||||
|
||||
auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
|
||||
try:
|
||||
self._websocket = await websocket_connect(
|
||||
auth_url,
|
||||
max_size=None,
|
||||
open_timeout=self._open_timeout,
|
||||
)
|
||||
except Exception as exc:
|
||||
await self.push_error(f"Xfyun ASR connection failed: {exc}", exception=exc)
|
||||
self._websocket = None
|
||||
return
|
||||
|
||||
self._receive_task = self.create_task(
|
||||
self._receive_messages(),
|
||||
name="xfyun_asr_receive",
|
||||
)
|
||||
|
||||
async def _finish_utterance(self) -> None:
|
||||
if not self._websocket or self._websocket.state is not State.OPEN:
|
||||
return
|
||||
|
||||
await self._flush_audio_buffer(final=True)
|
||||
if not self._sent_first_frame:
|
||||
await self._close_utterance()
|
||||
return
|
||||
|
||||
if not self._sent_final_frame:
|
||||
self._finalizing_turn = True
|
||||
await self._send_payload({"data": {"status": 2}})
|
||||
self.request_finalize()
|
||||
self._sent_final_frame = True
|
||||
|
||||
async def _close_utterance(self) -> None:
|
||||
current_task = asyncio.current_task()
|
||||
if self._receive_task and self._receive_task is not current_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
websocket = self._websocket
|
||||
self._websocket = None
|
||||
if websocket and websocket.state is State.OPEN:
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._audio_buffer.clear()
|
||||
self._sent_first_frame = False
|
||||
self._sent_final_frame = False
|
||||
self._finalizing_turn = False
|
||||
|
||||
async def _flush_audio_buffer(self, *, final: bool) -> None:
|
||||
while len(self._audio_buffer) >= self._frame_size:
|
||||
chunk = bytes(self._audio_buffer[: self._frame_size])
|
||||
del self._audio_buffer[: self._frame_size]
|
||||
await self._send_audio_chunk(chunk, status=1)
|
||||
|
||||
if final and self._audio_buffer:
|
||||
chunk = bytes(self._audio_buffer)
|
||||
self._audio_buffer.clear()
|
||||
await self._send_audio_chunk(chunk, status=1)
|
||||
|
||||
async def _send_audio_chunk(self, audio: bytes, *, status: int) -> None:
|
||||
if not audio:
|
||||
return
|
||||
|
||||
if not self._sent_first_frame:
|
||||
business = {
|
||||
"language": self._language,
|
||||
"domain": self._domain,
|
||||
"accent": self._accent,
|
||||
}
|
||||
if self._dynamic_correction:
|
||||
business["dwa"] = "wpgs"
|
||||
|
||||
payload = {
|
||||
"common": {"app_id": self._app_id},
|
||||
"business": business,
|
||||
"data": {
|
||||
"status": 0,
|
||||
"format": f"audio/L16;rate={self.sample_rate}",
|
||||
"encoding": self._encoding,
|
||||
"audio": base64.b64encode(audio).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
self._sent_first_frame = True
|
||||
else:
|
||||
payload = {
|
||||
"data": {
|
||||
"status": status,
|
||||
"format": f"audio/L16;rate={self.sample_rate}",
|
||||
"encoding": self._encoding,
|
||||
"audio": base64.b64encode(audio).decode("utf-8"),
|
||||
}
|
||||
}
|
||||
|
||||
await self._send_payload(payload)
|
||||
|
||||
async def _send_payload(self, payload: dict[str, Any]) -> None:
|
||||
if not self._websocket or self._websocket.state is not State.OPEN:
|
||||
return
|
||||
await self._websocket.send(json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
async def _receive_messages(self) -> None:
|
||||
websocket = self._websocket
|
||||
if not websocket:
|
||||
return
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
await self._process_response(json.loads(message))
|
||||
except Exception as exc:
|
||||
if self._websocket is websocket:
|
||||
await self.push_error(f"Xfyun ASR receive failed: {exc}", exception=exc)
|
||||
finally:
|
||||
if self._websocket is websocket:
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
|
||||
async def _process_response(self, payload: dict[str, Any]) -> None:
|
||||
code = payload.get("code", -1)
|
||||
if code != 0:
|
||||
message = payload.get("message", "unknown error")
|
||||
sid = payload.get("sid")
|
||||
await self.push_error(f"Xfyun ASR error code={code}, sid={sid}, message={message}")
|
||||
return
|
||||
|
||||
data = payload.get("data")
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
|
||||
is_final_response = data.get("status") == 2
|
||||
recognition = data.get("result")
|
||||
if isinstance(recognition, dict):
|
||||
text = self._apply_recognition_result(recognition)
|
||||
if text and text != self._last_text:
|
||||
self._last_text = text
|
||||
if not self._finalizing_turn and not is_final_response:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
_language_or_none(self._language),
|
||||
result=payload,
|
||||
)
|
||||
)
|
||||
|
||||
if is_final_response:
|
||||
final_text = self._last_text
|
||||
if final_text:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
final_text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
_language_or_none(self._language),
|
||||
result=payload,
|
||||
)
|
||||
)
|
||||
await self._close_utterance()
|
||||
|
||||
def _apply_recognition_result(self, recognition: dict[str, Any]) -> str:
|
||||
partial = _extract_text_from_result(recognition)
|
||||
if not partial:
|
||||
return self._last_text
|
||||
|
||||
if self._dynamic_correction and recognition.get("pgs") == "rpl" and recognition.get("rg"):
|
||||
start, end = recognition["rg"]
|
||||
if 1 <= start <= len(self._partials):
|
||||
self._partials[start - 1 : end] = [partial]
|
||||
else:
|
||||
logger.debug(f"Ignoring out-of-range Xfyun replacement rg={recognition['rg']}")
|
||||
else:
|
||||
self._partials.append(partial)
|
||||
|
||||
return "".join(self._partials)
|
||||
|
||||
|
||||
def _extract_text_from_result(result: dict[str, Any]) -> str:
|
||||
words: list[str] = []
|
||||
for item in result.get("ws", []):
|
||||
for candidate in item.get("cw", []):
|
||||
word = candidate.get("w")
|
||||
if word:
|
||||
words.append(word)
|
||||
return "".join(words)
|
||||
|
||||
|
||||
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.netloc
|
||||
path = parsed.path or "/v2/iat"
|
||||
date = format_datetime(datetime.now(timezone.utc), usegmt=True)
|
||||
request_line = f"GET {path} HTTP/1.1"
|
||||
signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
|
||||
signature_sha = hmac.new(
|
||||
api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
signature = base64.b64encode(signature_sha).decode("utf-8")
|
||||
authorization_origin = (
|
||||
f'api_key="{api_key}", algorithm="hmac-sha256", '
|
||||
f'headers="host date request-line", signature="{signature}"'
|
||||
)
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
|
||||
query = urlencode({"authorization": authorization, "date": date, "host": host})
|
||||
return f"{url}?{query}"
|
||||
|
||||
|
||||
def _language_or_none(value: str) -> Language | None:
|
||||
try:
|
||||
return Language(value)
|
||||
except ValueError:
|
||||
return None
|
||||
65
backend/services/pipecat/xfyun_config.py
Normal file
65
backend/services/pipecat/xfyun_config.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Parse Xfyun's three-part credential from ProviderCredential.api_key."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class XfyunCredential:
|
||||
app_id: str
|
||||
api_key: str
|
||||
api_secret: str
|
||||
|
||||
|
||||
def parse_xfyun_credential(value: str) -> XfyunCredential:
|
||||
"""Accept JSON in the existing api_key column.
|
||||
|
||||
Example:
|
||||
{"appId":"...","apiKey":"...","apiSecret":"..."}
|
||||
"""
|
||||
try:
|
||||
payload = json.loads(value)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(
|
||||
'Xfyun API Key must be JSON: {"appId":"...","apiKey":"...","apiSecret":"..."}'
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Xfyun API Key JSON must be an object")
|
||||
|
||||
credential = XfyunCredential(
|
||||
app_id=str(payload.get("appId") or payload.get("app_id") or "").strip(),
|
||||
api_key=str(payload.get("apiKey") or payload.get("api_key") or "").strip(),
|
||||
api_secret=str(
|
||||
payload.get("apiSecret") or payload.get("api_secret") or ""
|
||||
).strip(),
|
||||
)
|
||||
if not credential.app_id or not credential.api_key or not credential.api_secret:
|
||||
raise ValueError("Xfyun API Key JSON requires appId, apiKey, and apiSecret")
|
||||
return credential
|
||||
|
||||
|
||||
def websocket_url(value: str, default: str) -> str:
|
||||
url = (value or default).strip()
|
||||
if url.startswith("https://"):
|
||||
return f"wss://{url.removeprefix('https://')}"
|
||||
if url.startswith("http://"):
|
||||
return f"ws://{url.removeprefix('http://')}"
|
||||
return url
|
||||
|
||||
|
||||
def is_super_tts(model_id: str, api_url: str) -> bool:
|
||||
model = model_id.lower().replace("-", "_")
|
||||
return "super" in model or "/private/" in api_url.lower()
|
||||
|
||||
|
||||
def xfyun_language(value: str) -> str:
|
||||
normalized = (value or "zh_cn").lower().replace("-", "_")
|
||||
return {"zh": "zh_cn", "en": "en_us"}.get(normalized, normalized)
|
||||
|
||||
|
||||
def xfyun_speed(value: float) -> int:
|
||||
"""Reuse the existing OpenAI-style speed field where 1.0 means normal."""
|
||||
return max(0, min(100, round(value * 50 if value <= 4 else value)))
|
||||
391
backend/services/pipecat/xfyun_super_tts.py
Normal file
391
backend/services/pipecat/xfyun_super_tts.py
Normal file
@@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import format_datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.tts_service import TextAggregationMode, WebsocketTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as exc:
|
||||
logger.error(f"Exception: {exc}")
|
||||
logger.error("In order to use Xfyun Super TTS, install the websockets package.")
|
||||
raise Exception(f"Missing module: {exc}") from exc
|
||||
|
||||
from .xfyun_tts import _sanitize_text_for_tts
|
||||
|
||||
|
||||
DEFAULT_XFYUN_SUPER_TTS_URL = "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6"
|
||||
VALID_SAMPLE_RATES = {8000, 16000, 24000}
|
||||
|
||||
|
||||
class XfyunSuperTTSService(WebsocketTTSService):
|
||||
"""iFlytek/Xfyun Super Smart TTS using bidirectional WebSocket streaming.
|
||||
|
||||
The service keeps one Xfyun synthesis session open for a Pipecat turn. Each
|
||||
``run_tts`` call sends a text segment with status 0/1, while ``flush_audio``
|
||||
sends the terminal status 2 frame. Audio arrives on the receive task and is
|
||||
appended to the Pipecat audio context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app_id: str,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
voice: str,
|
||||
url: str | None = None,
|
||||
sample_rate: int = 16000,
|
||||
source_sample_rate: int = 24000,
|
||||
encoding: str = "raw",
|
||||
speed: int = 50,
|
||||
volume: int = 50,
|
||||
pitch: int = 50,
|
||||
oral_level: str = "mid",
|
||||
text_aggregation_mode: TextAggregationMode | str | None = TextAggregationMode.TOKEN,
|
||||
open_timeout: float = 30.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if isinstance(text_aggregation_mode, str):
|
||||
text_aggregation_mode = TextAggregationMode(text_aggregation_mode)
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=True,
|
||||
push_stop_frames=False,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=False,
|
||||
sample_rate=sample_rate,
|
||||
settings=TTSSettings(model=None, voice=voice, language=None),
|
||||
**kwargs,
|
||||
)
|
||||
self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
|
||||
self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
|
||||
self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
|
||||
self._voice = voice
|
||||
self._url = url or DEFAULT_XFYUN_SUPER_TTS_URL
|
||||
self._source_sample_rate = source_sample_rate
|
||||
self._encoding = encoding
|
||||
self._speed = speed
|
||||
self._volume = volume
|
||||
self._pitch = pitch
|
||||
self._oral_level = oral_level
|
||||
self._open_timeout = open_timeout
|
||||
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._active_context_id: str | None = None
|
||||
self._started_contexts: set[str] = set()
|
||||
self._seq_by_context: dict[str, int] = {}
|
||||
self._sent_text_bytes_by_context: dict[str, int] = {}
|
||||
self._stream_completed = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame) -> None:
|
||||
await super().start(frame)
|
||||
if not self._app_id or not self._api_key or not self._api_secret:
|
||||
await self.push_error(
|
||||
error_msg="Xfyun Super TTS requires app_id, api_key, and api_secret"
|
||||
)
|
||||
return
|
||||
if self._encoding != "raw":
|
||||
await self.push_error(error_msg="Xfyun Super TTS must use raw PCM audio in Pipecat")
|
||||
return
|
||||
if self._source_sample_rate not in VALID_SAMPLE_RATES:
|
||||
await self.push_error(
|
||||
error_msg=(
|
||||
"Xfyun Super TTS source_sample_rate must be one of "
|
||||
f"{sorted(VALID_SAMPLE_RATES)}"
|
||||
)
|
||||
)
|
||||
return
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame) -> None:
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self, context_id: str | None = None) -> None:
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
if flush_id not in self._started_contexts:
|
||||
return
|
||||
|
||||
logger.trace(f"{self}: flushing Xfyun Super TTS stream {flush_id}")
|
||||
await self._send_request_frame(flush_id, "", status=2)
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str) -> None:
|
||||
await self.stop_all_metrics()
|
||||
await self._reset_context(context_id)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
await super().on_audio_context_interrupted(context_id)
|
||||
|
||||
async def _connect(self) -> None:
|
||||
await super()._connect()
|
||||
await self._connect_websocket()
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
await super()._disconnect()
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self) -> None:
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
logger.debug("Connecting to Xfyun Super TTS")
|
||||
auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
|
||||
self._websocket = await websocket_connect(
|
||||
auth_url,
|
||||
max_size=None,
|
||||
open_timeout=self._open_timeout,
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as exc:
|
||||
self._websocket = None
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to Xfyun Super TTS: {exc}",
|
||||
exception=exc,
|
||||
)
|
||||
await self._call_event_handler("on_connection_error", f"{exc}")
|
||||
|
||||
async def _disconnect_websocket(self) -> None:
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Xfyun Super TTS")
|
||||
await self._websocket.close()
|
||||
except Exception as exc:
|
||||
await self.push_error(
|
||||
error_msg=f"Error closing Xfyun Super TTS websocket: {exc}",
|
||||
exception=exc,
|
||||
)
|
||||
finally:
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
self._active_context_id = None
|
||||
self._started_contexts.clear()
|
||||
self._seq_by_context.clear()
|
||||
self._sent_text_bytes_by_context.clear()
|
||||
self._stream_completed = False
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _receive_messages(self) -> None:
|
||||
async for raw_message in self._get_websocket():
|
||||
try:
|
||||
message = json.loads(raw_message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"{self}: received non-JSON Xfyun Super TTS message: {raw_message!r}")
|
||||
continue
|
||||
|
||||
header = message.get("header") or {}
|
||||
code = header.get("code", -1)
|
||||
sid = header.get("sid")
|
||||
context_id = self._active_context_id
|
||||
|
||||
if code != 0:
|
||||
error_message = header.get("message", "unknown error")
|
||||
await self.push_error(
|
||||
error_msg=f"Xfyun Super TTS error code={code}, sid={sid}: {error_message}"
|
||||
)
|
||||
if context_id and self.audio_context_available(context_id):
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.remove_audio_context(context_id)
|
||||
if context_id:
|
||||
await self._reset_context(context_id)
|
||||
continue
|
||||
|
||||
audio_obj = (message.get("payload") or {}).get("audio") or {}
|
||||
audio_b64 = audio_obj.get("audio")
|
||||
if audio_b64 and context_id and self.audio_context_available(context_id):
|
||||
await self.stop_ttfb_metrics()
|
||||
audio = base64.b64decode(audio_b64)
|
||||
if self._source_sample_rate != self.sample_rate:
|
||||
audio = await self._resampler.resample(
|
||||
audio, self._source_sample_rate, self.sample_rate
|
||||
)
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
|
||||
audio_status = audio_obj.get("status")
|
||||
header_status = header.get("status")
|
||||
if audio_status == 2 or header_status == 2:
|
||||
if context_id and self.audio_context_available(context_id):
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.remove_audio_context(context_id)
|
||||
if context_id:
|
||||
await self._reset_context(context_id)
|
||||
self._stream_completed = True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
|
||||
sanitized = _sanitize_text_for_tts(text)
|
||||
if not sanitized:
|
||||
return
|
||||
|
||||
if not self._is_streaming_tokens:
|
||||
logger.debug(f"{self}: Generating Xfyun Super TTS [{sanitized}]")
|
||||
else:
|
||||
logger.trace(f"{self}: Generating Xfyun Super TTS [{sanitized}]")
|
||||
|
||||
if self._stream_completed and self._websocket:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if self._active_context_id and self._active_context_id != context_id:
|
||||
yield ErrorFrame(
|
||||
error=(
|
||||
"Xfyun Super TTS supports one active synthesis stream per WebSocket; "
|
||||
f"active={self._active_context_id}, new={context_id}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
status = 0 if context_id not in self._started_contexts else 1
|
||||
await self._send_request_frame(context_id, sanitized, status=status)
|
||||
await self.start_tts_usage_metrics(sanitized)
|
||||
except Exception as exc:
|
||||
yield ErrorFrame(error=f"Xfyun Super TTS request failed: {exc}")
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
yield None
|
||||
|
||||
async def _send_request_frame(self, context_id: str, text: str, *, status: int) -> None:
|
||||
if status == 0:
|
||||
self._active_context_id = context_id
|
||||
self._started_contexts.add(context_id)
|
||||
|
||||
seq = self._seq_by_context.get(context_id, 0)
|
||||
text_bytes = text.encode("utf-8")
|
||||
total_bytes = self._sent_text_bytes_by_context.get(context_id, 0) + len(text_bytes)
|
||||
if total_bytes > 65536:
|
||||
raise ValueError("Xfyun Super TTS text must not exceed 64K UTF-8 bytes per stream")
|
||||
|
||||
frame = self._build_request_frame(text, status=status, seq=seq)
|
||||
await self._get_websocket().send(json.dumps(frame, ensure_ascii=False))
|
||||
|
||||
self._seq_by_context[context_id] = seq + 1
|
||||
self._sent_text_bytes_by_context[context_id] = total_bytes
|
||||
|
||||
def _build_request_frame(self, text: str, *, status: int, seq: int) -> dict[str, Any]:
|
||||
return {
|
||||
"header": {
|
||||
"app_id": self._app_id,
|
||||
"status": status,
|
||||
},
|
||||
"parameter": {
|
||||
"oral": {
|
||||
"oral_level": self._oral_level,
|
||||
},
|
||||
"tts": {
|
||||
"vcn": self._voice,
|
||||
"speed": self._speed,
|
||||
"volume": self._volume,
|
||||
"pitch": self._pitch,
|
||||
"bgs": 0,
|
||||
"reg": 0,
|
||||
"rdn": 0,
|
||||
"rhy": 0,
|
||||
"audio": {
|
||||
"encoding": self._encoding,
|
||||
"sample_rate": self._source_sample_rate,
|
||||
"channels": 1,
|
||||
"bit_depth": 16,
|
||||
"frame_size": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
"payload": {
|
||||
"text": {
|
||||
"encoding": "utf8",
|
||||
"compress": "raw",
|
||||
"format": "plain",
|
||||
"status": status,
|
||||
"seq": seq,
|
||||
"text": base64.b64encode(text.encode("utf-8")).decode("utf-8"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def _reset_context(self, context_id: str) -> None:
|
||||
self._started_contexts.discard(context_id)
|
||||
self._seq_by_context.pop(context_id, None)
|
||||
self._sent_text_bytes_by_context.pop(context_id, None)
|
||||
if self._active_context_id == context_id:
|
||||
self._active_context_id = None
|
||||
|
||||
|
||||
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
|
||||
raise ValueError(f"invalid Xfyun Super TTS WebSocket URL: {url}")
|
||||
|
||||
host = parsed.hostname
|
||||
path = parsed.path or "/"
|
||||
date = format_datetime(datetime.now(timezone.utc), usegmt=True)
|
||||
request_line = f"GET {path} HTTP/1.1"
|
||||
signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
|
||||
signature_sha = hmac.new(
|
||||
api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
signature = base64.b64encode(signature_sha).decode("utf-8")
|
||||
authorization_origin = (
|
||||
f'api_key="{api_key}", algorithm="hmac-sha256", '
|
||||
f'headers="host date request-line", signature="{signature}"'
|
||||
)
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
|
||||
query = urlencode({"authorization": authorization, "date": date, "host": host})
|
||||
return f"{url}?{query}"
|
||||
257
backend/services/pipecat/xfyun_tts.py
Normal file
257
backend/services/pipecat/xfyun_tts.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import format_datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
|
||||
DEFAULT_XFYUN_TTS_URL = "wss://tts-api.xfyun.cn/v2/tts"
|
||||
|
||||
# Strip characters Xfyun's online TTS cannot synthesize. The engine silently
|
||||
# rejects (or returns empty audio for) text containing emoji and other
|
||||
# non-BMP symbols, which surfaces as "request finished without audio data".
|
||||
_EMOJI_AND_SYMBOL_RE = re.compile(
|
||||
"["
|
||||
"\U0001F300-\U0001FAFF" # misc pictographs, emoji, symbols, transport, etc.
|
||||
"\U00002600-\U000027BF" # misc symbols and dingbats
|
||||
"\U0001F1E6-\U0001F1FF" # regional indicators (flags)
|
||||
"\uFE00-\uFE0F" # variation selectors
|
||||
"\u200D" # zero-width joiner
|
||||
"]",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
|
||||
|
||||
class XfyunTTSService(TTSService):
|
||||
"""iFlytek/Xfyun online TTS service for Pipecat.
|
||||
|
||||
Xfyun's API is not OpenAI-compatible. It uses a signed WebSocket URL,
|
||||
receives one JSON request per synthesis, and streams text WebSocket
|
||||
messages containing base64-encoded audio chunks. This service requests
|
||||
raw PCM so the chunks can become Pipecat audio frames without MP3 decode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app_id: str,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
voice: str,
|
||||
url: str | None = None,
|
||||
sample_rate: int = 16000,
|
||||
source_sample_rate: int = 16000,
|
||||
encoding: str = "raw",
|
||||
text_encoding: str = "UTF8",
|
||||
speed: int = 50,
|
||||
volume: int = 50,
|
||||
pitch: int = 50,
|
||||
timeout: float = 30.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
settings=TTSSettings(model=None, voice=voice, language=None),
|
||||
**kwargs,
|
||||
)
|
||||
self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
|
||||
self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
|
||||
self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
|
||||
self._voice = voice
|
||||
self._url = url or DEFAULT_XFYUN_TTS_URL
|
||||
self._source_sample_rate = source_sample_rate
|
||||
self._encoding = encoding
|
||||
self._text_encoding = text_encoding
|
||||
self._speed = speed
|
||||
self._volume = volume
|
||||
self._pitch = pitch
|
||||
self._timeout = timeout
|
||||
self._last_failure_detail: str | None = None
|
||||
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
if not text:
|
||||
return
|
||||
|
||||
if not self._app_id or not self._api_key or not self._api_secret:
|
||||
yield ErrorFrame(error="Xfyun TTS requires app_id, api_key, and api_secret")
|
||||
return
|
||||
|
||||
sanitized = _sanitize_text_for_tts(text)
|
||||
if not sanitized:
|
||||
logger.debug(
|
||||
f"{self}: skipping Xfyun TTS, text became empty after sanitization "
|
||||
f"(original={text!r})"
|
||||
)
|
||||
return
|
||||
|
||||
if sanitized != text:
|
||||
logger.debug(
|
||||
f"{self}: sanitized Xfyun TTS text "
|
||||
f"(original={text!r}, sanitized={sanitized!r})"
|
||||
)
|
||||
|
||||
if len(sanitized.encode("utf-8")) >= 8000:
|
||||
yield ErrorFrame(error="Xfyun TTS text must be less than 8000 UTF-8 bytes")
|
||||
return
|
||||
|
||||
if self._encoding != "raw":
|
||||
yield ErrorFrame(error="Xfyun TTS is configured for PCM output; set aue/encoding to raw")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.start_tts_usage_metrics(sanitized)
|
||||
|
||||
first_frame = True
|
||||
async for frame in self._stream_audio_frames_from_iterator(
|
||||
self._iter_audio_chunks(sanitized),
|
||||
in_sample_rate=self._source_sample_rate,
|
||||
context_id=context_id,
|
||||
):
|
||||
if first_frame:
|
||||
await self.stop_ttfb_metrics()
|
||||
first_frame = False
|
||||
yield frame
|
||||
|
||||
if first_frame:
|
||||
detail = self._last_failure_detail or "no audio frames received"
|
||||
yield ErrorFrame(
|
||||
error=(
|
||||
f"Xfyun TTS request finished without audio data ({detail}); "
|
||||
f"text={sanitized!r}"
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
yield ErrorFrame(error=f"Xfyun TTS request failed: {exc}")
|
||||
|
||||
async def _iter_audio_chunks(self, text: str) -> AsyncIterator[bytes]:
|
||||
request = self._build_request_frame(text)
|
||||
auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
|
||||
|
||||
self._last_failure_detail = None
|
||||
frames_received = 0
|
||||
audio_bytes_received = 0
|
||||
last_status: int | None = None
|
||||
last_sid: str | None = None
|
||||
saw_status_2 = False
|
||||
|
||||
async with connect(auth_url, max_size=None, open_timeout=self._timeout) as websocket:
|
||||
await websocket.send(json.dumps(request, ensure_ascii=False))
|
||||
|
||||
async for raw_message in websocket:
|
||||
frames_received += 1
|
||||
payload = json.loads(raw_message)
|
||||
code = payload.get("code", -1)
|
||||
sid = payload.get("sid")
|
||||
if sid:
|
||||
last_sid = sid
|
||||
if code != 0:
|
||||
err_msg = payload.get("message", "unknown error")
|
||||
raise RuntimeError(f"code={code}, sid={sid}, message={err_msg}")
|
||||
|
||||
data = payload.get("data")
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
|
||||
last_status = data.get("status", last_status)
|
||||
|
||||
audio_b64 = data.get("audio")
|
||||
if audio_b64:
|
||||
audio_bytes = base64.b64decode(audio_b64)
|
||||
audio_bytes_received += len(audio_bytes)
|
||||
yield audio_bytes
|
||||
|
||||
if data.get("status") == 2:
|
||||
saw_status_2 = True
|
||||
break
|
||||
|
||||
if audio_bytes_received == 0:
|
||||
self._last_failure_detail = (
|
||||
f"frames={frames_received}, audio_bytes=0, "
|
||||
f"last_status={last_status}, saw_status_2={saw_status_2}, sid={last_sid}"
|
||||
)
|
||||
logger.warning(
|
||||
f"{self}: Xfyun TTS produced no audio ({self._last_failure_detail})"
|
||||
)
|
||||
|
||||
def _build_request_frame(self, text: str) -> dict[str, Any]:
|
||||
business: dict[str, Any] = {
|
||||
"aue": self._encoding,
|
||||
"auf": f"audio/L16;rate={self._source_sample_rate}",
|
||||
"vcn": self._voice,
|
||||
"speed": self._speed,
|
||||
"volume": self._volume,
|
||||
"pitch": self._pitch,
|
||||
"tte": self._text_encoding,
|
||||
}
|
||||
|
||||
return {
|
||||
"common": {"app_id": self._app_id},
|
||||
"business": business,
|
||||
"data": {
|
||||
"status": 2,
|
||||
"text": base64.b64encode(text.encode("utf-8")).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_text_for_tts(text: str) -> str:
|
||||
"""Strip characters Xfyun's online TTS cannot synthesize.
|
||||
|
||||
The Xfyun ``/v2/tts`` engine silently drops or rejects emoji, pictographs,
|
||||
dingbats, regional-indicator flags, variation selectors, and zero-width
|
||||
joiners. When such characters appear in the input the synthesis can
|
||||
finish without any audio data ("Xfyun TTS request finished without audio
|
||||
data"). We also drop control characters (other than common whitespace)
|
||||
and "Symbol, Other" codepoints, then collapse runs of whitespace.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
cleaned = _EMOJI_AND_SYMBOL_RE.sub("", text)
|
||||
filtered: list[str] = []
|
||||
for ch in cleaned:
|
||||
category = unicodedata.category(ch)
|
||||
if category == "So":
|
||||
continue
|
||||
if category.startswith("C") and ch not in ("\n", "\r", "\t"):
|
||||
continue
|
||||
filtered.append(ch)
|
||||
return re.sub(r"\s+", " ", "".join(filtered)).strip()
|
||||
|
||||
|
||||
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.netloc
|
||||
path = parsed.path or "/v2/tts"
|
||||
date = format_datetime(datetime.now(timezone.utc), usegmt=True)
|
||||
request_line = f"GET {path} HTTP/1.1"
|
||||
signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
|
||||
signature_sha = hmac.new(
|
||||
api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
signature = base64.b64encode(signature_sha).decode("utf-8")
|
||||
authorization_origin = (
|
||||
f'api_key="{api_key}", algorithm="hmac-sha256", '
|
||||
f'headers="host date request-line", signature="{signature}"'
|
||||
)
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
|
||||
query = urlencode({"authorization": authorization, "date": date, "host": host})
|
||||
return f"{url}?{query}"
|
||||
@@ -60,6 +60,7 @@ import { AuraVisualizer } from "@/components/ui/aura-visualizer";
|
||||
import { NebulaVisualizer } from "@/components/ui/nebula-visualizer";
|
||||
import { SpectrumVisualizer } from "@/components/ui/spectrum-visualizer";
|
||||
import { WaveVisualizer } from "@/components/ui/wave-visualizer";
|
||||
import { WaveformTimelinePanel } from "@/components/ui/waveform-timeline";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
@@ -1856,6 +1857,7 @@ function DebugVoicePanel({
|
||||
error,
|
||||
micWarning,
|
||||
localStream,
|
||||
remoteStream,
|
||||
messages,
|
||||
audioInputs,
|
||||
selectedDeviceId,
|
||||
@@ -2032,6 +2034,13 @@ function DebugVoicePanel({
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 底部双轨波形:用户麦克风 + 助手音频,可折叠 */}
|
||||
<WaveformTimelinePanel
|
||||
userStream={localStream}
|
||||
agentStream={remoteStream}
|
||||
active={status === "connected"}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -809,11 +809,18 @@ export function ComponentsModelsPage() {
|
||||
htmlFor="model-api-key"
|
||||
hint={{
|
||||
description:
|
||||
"访问模型服务的鉴权密钥,由服务商控制台生成,请妥善保管勿泄露。",
|
||||
example: "sk-xxxxxxxx",
|
||||
form.interfaceType === "xfyun"
|
||||
? "讯飞需要三段凭证,使用 JSON 存入现有 API Key 字段。"
|
||||
: "访问模型服务的鉴权密钥,由服务商控制台生成,请妥善保管勿泄露。",
|
||||
example:
|
||||
form.interfaceType === "xfyun"
|
||||
? '{"appId":"...","apiKey":"...","apiSecret":"..."}'
|
||||
: "sk-xxxxxxxx",
|
||||
}}
|
||||
>
|
||||
API Key
|
||||
{form.interfaceType === "xfyun"
|
||||
? "Xfyun Credential JSON"
|
||||
: "API Key"}
|
||||
</FieldLabel>
|
||||
{hasStoredApiKey && (
|
||||
<div className="mb-2 flex items-center gap-2 text-xs text-muted-foreground">
|
||||
@@ -832,7 +839,9 @@ export function ComponentsModelsPage() {
|
||||
placeholder={
|
||||
hasStoredApiKey
|
||||
? "已配置,留空则保持不变"
|
||||
: "请输入 API Key"
|
||||
: form.interfaceType === "xfyun"
|
||||
? '{"appId":"...","apiKey":"...","apiSecret":"..."}'
|
||||
: "请输入 API Key"
|
||||
}
|
||||
autoComplete="new-password"
|
||||
className="border-hairline-strong bg-background pr-10 text-foreground placeholder:text-muted-soft"
|
||||
@@ -852,6 +861,12 @@ export function ComponentsModelsPage() {
|
||||
仅显示当前密钥首尾用于识别。留空可保持原密钥,输入新值将覆盖原密钥。
|
||||
</p>
|
||||
)}
|
||||
{form.interfaceType === "xfyun" && (
|
||||
<p className="mt-2 text-xs leading-5 text-muted-foreground">
|
||||
ASR 的模型 ID 使用 iat;普通 TTS 使用 tts;超拟人 TTS 使用
|
||||
supertts。超拟人服务也可通过包含 /private/ 的 API URL 自动识别。
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
256
frontend/src/components/ui/waveform-timeline.tsx
Normal file
256
frontend/src/components/ui/waveform-timeline.tsx
Normal file
@@ -0,0 +1,256 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { Activity, ChevronDown, ChevronUp } from "lucide-react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useAudioAnalyser } from "@/hooks/use-audio-analyser";
|
||||
import {
|
||||
adaptPalette,
|
||||
isDarkTheme,
|
||||
readPalette,
|
||||
rgba,
|
||||
} from "@/lib/visualizer-palette";
|
||||
|
||||
/** 每格条形代表的音频时长(ms),决定时间轴滚动节奏 */
|
||||
const SAMPLE_MS = 50;
|
||||
/** 条形宽度/间距(px):滚动速度 = (BAR_WIDTH+BAR_GAP) * 1000/SAMPLE_MS px/s */
|
||||
const BAR_WIDTH = 2;
|
||||
const BAR_GAP = 1;
|
||||
const BAR_STEP = BAR_WIDTH + BAR_GAP;
|
||||
/** 历史保留上限:2 分钟,超出后丢最旧的样本 */
|
||||
const MAX_SAMPLES = (2 * 60 * 1000) / SAMPLE_MS;
|
||||
/** 时间刻度间隔(ms) */
|
||||
const TICK_MS = 5_000;
|
||||
/** 顶部时间轴高度(px) */
|
||||
const AXIS_HEIGHT = 16;
|
||||
|
||||
type History = {
|
||||
/** 每 SAMPLE_MS 一条的 RMS 强度(0~1),user/agent 等长同步推进 */
|
||||
user: number[];
|
||||
agent: number[];
|
||||
/** 因超出上限被丢弃的最旧样本数,用于换算样本对应的会话时间 */
|
||||
dropped: number;
|
||||
/** 上次采样的时间戳(performance.now) */
|
||||
lastSampleAt: number;
|
||||
};
|
||||
|
||||
function makeHistory(): History {
|
||||
return { user: [], agent: [], dropped: 0, lastSampleAt: 0 };
|
||||
}
|
||||
|
||||
/** 当前时域 RMS 强度(0~1);放大系数与 WaveVisualizer 一致,让小音量也可见 */
|
||||
function rmsLevel(node: AnalyserNode | null, buf: Uint8Array<ArrayBuffer>): number {
|
||||
if (!node) return 0;
|
||||
node.getByteTimeDomainData(buf);
|
||||
let sum = 0;
|
||||
for (let i = 0; i < node.fftSize; i++) {
|
||||
const d = (buf[i] - 128) / 128;
|
||||
sum += d * d;
|
||||
}
|
||||
return Math.min(1, Math.sqrt(sum / node.fftSize) * 3.2);
|
||||
}
|
||||
|
||||
/** 会话内毫秒 → m:ss 刻度文本 */
|
||||
function formatTick(ms: number): string {
|
||||
const total = Math.round(ms / 1000);
|
||||
const m = Math.floor(total / 60);
|
||||
const s = total % 60;
|
||||
return `${m}:${String(s).padStart(2, "0")}`;
|
||||
}
|
||||
|
||||
export type WaveformTimelineProps = {
|
||||
/** 用户麦克风流(本地) */
|
||||
userStream: MediaStream | null;
|
||||
/** 助手音频流(WebRTC 远端) */
|
||||
agentStream: MediaStream | null;
|
||||
/** 会话进行中才采样;结束后画面冻结,新会话开始时清空重来 */
|
||||
active: boolean;
|
||||
className?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* 双轨波形时间轴:上轨「我」(麦克风)、下轨「助手」(远端音频),
|
||||
* 按固定节拍采样 RMS 音量,最新样本贴右缘向左滚动,顶部带 m:ss 时间刻度。
|
||||
* 配色取自设计 token(--gradient-*),自动跟随明暗主题。
|
||||
*/
|
||||
export function WaveformTimeline({
|
||||
userStream,
|
||||
agentStream,
|
||||
active,
|
||||
className,
|
||||
}: WaveformTimelineProps) {
|
||||
const canvasRef = React.useRef<HTMLCanvasElement>(null);
|
||||
const historyRef = React.useRef<History>(makeHistory());
|
||||
const activeRef = React.useRef(active);
|
||||
|
||||
// active 传 stream 是否存在,避免 useAudioAnalyser 在缺流时去申请麦克风
|
||||
const userAnalyserRef = useAudioAnalyser({
|
||||
active: active && Boolean(userStream),
|
||||
stream: userStream,
|
||||
smoothingTimeConstant: 0.5,
|
||||
});
|
||||
const agentAnalyserRef = useAudioAnalyser({
|
||||
active: active && Boolean(agentStream),
|
||||
stream: agentStream,
|
||||
smoothingTimeConstant: 0.5,
|
||||
});
|
||||
|
||||
// 新会话开始时清空上一轮历史
|
||||
React.useEffect(() => {
|
||||
activeRef.current = active;
|
||||
if (active) {
|
||||
historyRef.current = makeHistory();
|
||||
}
|
||||
}, [active]);
|
||||
|
||||
React.useEffect(() => {
|
||||
const canvas = canvasRef.current;
|
||||
if (!canvas) return;
|
||||
const ctx = canvas.getContext("2d");
|
||||
if (!ctx) return;
|
||||
|
||||
const timeBuf = new Uint8Array(2048);
|
||||
let raf = 0;
|
||||
|
||||
const draw = () => {
|
||||
raf = requestAnimationFrame(draw);
|
||||
const w = canvas.clientWidth;
|
||||
const h = canvas.clientHeight;
|
||||
if (!w || !h) return;
|
||||
|
||||
const dpr = Math.min(window.devicePixelRatio || 1, 2);
|
||||
if (
|
||||
canvas.width !== Math.round(w * dpr) ||
|
||||
canvas.height !== Math.round(h * dpr)
|
||||
) {
|
||||
canvas.width = Math.round(w * dpr);
|
||||
canvas.height = Math.round(h * dpr);
|
||||
}
|
||||
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
|
||||
ctx.clearRect(0, 0, w, h);
|
||||
|
||||
// 采样:按固定节拍推入历史,帧率波动时补齐;长时间空窗(面板折叠)则跳过
|
||||
const hist = historyRef.current;
|
||||
if (activeRef.current) {
|
||||
const now = performance.now();
|
||||
if (now - hist.lastSampleAt > 1000) hist.lastSampleAt = now;
|
||||
while (now - hist.lastSampleAt >= SAMPLE_MS) {
|
||||
hist.lastSampleAt += SAMPLE_MS;
|
||||
hist.user.push(rmsLevel(userAnalyserRef.current, timeBuf));
|
||||
hist.agent.push(rmsLevel(agentAnalyserRef.current, timeBuf));
|
||||
if (hist.user.length > MAX_SAMPLES) {
|
||||
hist.user.shift();
|
||||
hist.agent.shift();
|
||||
hist.dropped += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const palette = adaptPalette(readPalette(canvas), isDarkTheme());
|
||||
const textColor = getComputedStyle(canvas).color;
|
||||
const rowH = (h - AXIS_HEIGHT) / 2;
|
||||
const n = hist.user.length;
|
||||
const ticksEvery = TICK_MS / SAMPLE_MS;
|
||||
|
||||
ctx.font = '10px "Inter", system-ui, sans-serif';
|
||||
ctx.textBaseline = "middle";
|
||||
|
||||
// 时间刻度:竖向网格线 + 顶部 m:ss 标签
|
||||
ctx.textAlign = "center";
|
||||
for (let i = 0; i < n; i++) {
|
||||
const sampleIndex = hist.dropped + i;
|
||||
if (sampleIndex % ticksEvery !== 0) continue;
|
||||
const x = w - (n - i) * BAR_STEP;
|
||||
if (x < 0) continue;
|
||||
ctx.fillStyle = textColor;
|
||||
ctx.globalAlpha = 0.12;
|
||||
ctx.fillRect(x, AXIS_HEIGHT, 1, h - AXIS_HEIGHT);
|
||||
ctx.globalAlpha = 0.75;
|
||||
ctx.fillText(formatTick(sampleIndex * SAMPLE_MS), Math.max(14, x), AXIS_HEIGHT / 2);
|
||||
}
|
||||
|
||||
const rows = [
|
||||
{ label: "我", levels: hist.user, color: palette.sky },
|
||||
{ label: "助手", levels: hist.agent, color: palette.lav },
|
||||
];
|
||||
|
||||
rows.forEach((row, r) => {
|
||||
const cy = AXIS_HEIGHT + rowH * r + rowH / 2;
|
||||
|
||||
// 中线
|
||||
ctx.globalAlpha = 1;
|
||||
ctx.fillStyle = rgba(row.color, 0.28);
|
||||
ctx.fillRect(0, cy - 0.5, w, 1);
|
||||
|
||||
// 音量条:最新样本贴右缘,向左回溯到画布边界为止
|
||||
ctx.fillStyle = rgba(row.color, 0.9);
|
||||
const maxBarH = rowH * 0.86;
|
||||
for (let i = n - 1; i >= 0; i--) {
|
||||
const x = w - (n - i) * BAR_STEP;
|
||||
if (x + BAR_WIDTH < 0) break;
|
||||
const bh = Math.max(1.5, row.levels[i] * maxBarH);
|
||||
ctx.fillRect(x, cy - bh / 2, BAR_WIDTH, bh);
|
||||
}
|
||||
|
||||
// 轨道标签
|
||||
ctx.globalAlpha = 0.85;
|
||||
ctx.fillStyle = textColor;
|
||||
ctx.textAlign = "left";
|
||||
ctx.fillText(row.label, 8, cy);
|
||||
ctx.textAlign = "center";
|
||||
});
|
||||
|
||||
ctx.globalAlpha = 1;
|
||||
};
|
||||
|
||||
raf = requestAnimationFrame(draw);
|
||||
return () => cancelAnimationFrame(raf);
|
||||
}, [userAnalyserRef, agentAnalyserRef]);
|
||||
|
||||
return (
|
||||
<canvas
|
||||
ref={canvasRef}
|
||||
role="img"
|
||||
aria-label="用户与助手语音波形时间轴"
|
||||
className={cn("block select-none text-muted-foreground", className)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export type WaveformTimelinePanelProps = WaveformTimelineProps & {
|
||||
/** 初始是否展开 */
|
||||
defaultOpen?: boolean;
|
||||
};
|
||||
|
||||
/** 可折叠的波形底栏:头部一行可点击展开/收起,展开后显示双轨时间轴 */
|
||||
export function WaveformTimelinePanel({
|
||||
defaultOpen = true,
|
||||
className,
|
||||
...timeline
|
||||
}: WaveformTimelinePanelProps) {
|
||||
const [open, setOpen] = React.useState(defaultOpen);
|
||||
|
||||
return (
|
||||
<div className={cn("shrink-0 border-t border-hairline", className)}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setOpen((v) => !v)}
|
||||
aria-expanded={open}
|
||||
className="flex h-9 w-full items-center gap-2 px-4 text-xs font-medium text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<Activity size={13} className="shrink-0" />
|
||||
波形监视
|
||||
<span className="ml-auto flex h-5 w-5 items-center justify-center text-muted-soft">
|
||||
{open ? <ChevronDown size={14} /> : <ChevronUp size={14} />}
|
||||
</span>
|
||||
</button>
|
||||
|
||||
{open && (
|
||||
<div className="h-28 px-3 pb-3">
|
||||
<WaveformTimeline {...timeline} className="h-full w-full" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -74,6 +74,8 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [micWarning, setMicWarning] = useState<string | null>(null);
|
||||
const [localStream, setLocalStream] = useState<MediaStream | null>(null);
|
||||
// 远端(助手 TTS)媒体流:除挂到 <audio> 播放外,也暴露给波形可视化
|
||||
const [remoteStream, setRemoteStream] = useState<MediaStream | null>(null);
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
// 可选麦克风列表与当前选择(空串表示交给浏览器选默认设备)
|
||||
const [audioInputs, setAudioInputs] = useState<MediaDeviceInfo[]>([]);
|
||||
@@ -92,7 +94,8 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
if (!navigator.mediaDevices?.enumerateDevices) return;
|
||||
try {
|
||||
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||
const inputs = devices.filter((d) => d.kind === "audioinput");
|
||||
// 未授权时设备可能没有 deviceId(空串),无法选择,直接过滤掉
|
||||
const inputs = devices.filter((d) => d.kind === "audioinput" && d.deviceId);
|
||||
setAudioInputs(inputs);
|
||||
// 选中的设备已被拔出时,回退到浏览器默认设备
|
||||
if (
|
||||
@@ -157,6 +160,7 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
const disconnect = useCallback(() => {
|
||||
releaseResources();
|
||||
setLocalStream(null);
|
||||
setRemoteStream(null);
|
||||
setError(null);
|
||||
setMicWarning(null);
|
||||
setStatus("idle");
|
||||
@@ -166,6 +170,7 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
(message: string) => {
|
||||
releaseResources();
|
||||
setLocalStream(null);
|
||||
setRemoteStream(null);
|
||||
setError(message);
|
||||
setStatus("failed");
|
||||
},
|
||||
@@ -312,9 +317,11 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
};
|
||||
|
||||
pc.ontrack = (e) => {
|
||||
if (e.track.kind === "audio" && audioRef.current) {
|
||||
audioRef.current.srcObject =
|
||||
e.streams[0] ?? new MediaStream([e.track]);
|
||||
if (e.track.kind !== "audio") return;
|
||||
const remote = e.streams[0] ?? new MediaStream([e.track]);
|
||||
setRemoteStream(remote);
|
||||
if (audioRef.current) {
|
||||
audioRef.current.srcObject = remote;
|
||||
void audioRef.current.play().catch(() => {});
|
||||
}
|
||||
};
|
||||
@@ -384,6 +391,7 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
error,
|
||||
micWarning,
|
||||
localStream,
|
||||
remoteStream,
|
||||
messages,
|
||||
audioInputs,
|
||||
selectedDeviceId,
|
||||
|
||||
Reference in New Issue
Block a user