- Introduce new Xfyun ASR and TTS services, enabling integration with iFlytek's voice recognition and synthesis capabilities. - Update AssistantConfig model to include interface types for STT and TTS. - Enhance credential testing to validate Xfyun credentials. - Modify service factory to create Xfyun services based on configuration. - Update README with new configuration details for Xfyun integration. - Add new frontend components for visualizing audio streams and managing user interactions.
354 lines
12 KiB
Python
354 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import os
|
|
from collections.abc import AsyncGenerator
|
|
from datetime import datetime, timezone
|
|
from email.utils import format_datetime
|
|
from typing import Any
|
|
from urllib.parse import urlencode, urlparse
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
EndFrame,
|
|
Frame,
|
|
InterimTranscriptionFrame,
|
|
TranscriptionFrame,
|
|
UserStoppedSpeakingFrame,
|
|
VADUserStartedSpeakingFrame,
|
|
)
|
|
from pipecat.processors.frame_processor import FrameDirection
|
|
from pipecat.services.settings import STTSettings
|
|
from pipecat.services.stt_service import STTService
|
|
from pipecat.transcriptions.language import Language
|
|
from pipecat.utils.time import time_now_iso8601
|
|
from websockets.asyncio.client import connect as websocket_connect
|
|
from websockets.protocol import State
|
|
|
|
|
|
# WebSocket endpoint used by XfyunASRService when no explicit ``url`` is given.
DEFAULT_XFYUN_ASR_URL = "wss://iat-api.xfyun.cn/v2/iat"
|
|
|
|
|
|
class XfyunASRService(STTService):
    """iFlytek/Xfyun streaming voice dictation service for Pipecat.

    Audio handed to :meth:`run_stt` is buffered, cut into ``frame_size``-byte
    chunks and streamed as JSON messages over a WebSocket. Recognition
    results are pushed downstream as ``InterimTranscriptionFrame`` while an
    utterance is in progress and as a single ``TranscriptionFrame`` once the
    server reports the final result (``data.status == 2``), after which the
    connection is closed. A new connection is opened on the next
    ``VADUserStartedSpeakingFrame`` or incoming audio chunk.
    """

    def __init__(
        self,
        *,
        app_id: str,
        api_key: str,
        api_secret: str,
        url: str | None = None,
        language: str = "zh_cn",
        domain: str = "iat",
        accent: str = "mandarin",
        sample_rate: int = 16000,
        encoding: str = "raw",
        frame_size: int = 1280,
        open_timeout: float = 10.0,
        dynamic_correction: bool = False,
        **kwargs,
    ) -> None:
        """Configure the service; no connection is opened here.

        Args:
            app_id: Xfyun application id; falls back to ``XFYUN_APP_ID``
                from the environment when empty.
            api_key: Xfyun API key; falls back to ``XFYUN_API_KEY``.
            api_secret: Xfyun API secret; falls back to ``XFYUN_API_SECRET``.
            url: WebSocket endpoint; ``DEFAULT_XFYUN_ASR_URL`` when omitted.
            language: Sent as ``business.language`` in the first message.
            domain: Sent as ``business.domain`` in the first message.
            accent: Sent as ``business.accent`` in the first message.
            sample_rate: Forwarded to ``STTService``; a connection is only
                opened when the effective rate is 8000 or 16000.
            encoding: Sent as ``data.encoding`` with every audio message.
            frame_size: Number of audio bytes per streamed message.
            open_timeout: Timeout passed to the WebSocket connect call.
            dynamic_correction: When true, requests ``"dwa": "wpgs"`` and
                honours replacement (``pgs == "rpl"``) results.
            **kwargs: Forwarded unchanged to ``STTService``.
        """
        super().__init__(
            sample_rate=sample_rate,
            settings=STTSettings(model=None, language=language),
            **kwargs,
        )
        self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
        self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
        self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
        self._url = url or DEFAULT_XFYUN_ASR_URL
        self._language = language
        self._domain = domain
        self._accent = accent
        self._encoding = encoding
        self._frame_size = frame_size
        self._open_timeout = open_timeout
        self._dynamic_correction = dynamic_correction

        self._websocket = None
        self._receive_task = None
        self._audio_buffer = bytearray()
        self._sent_first_frame = False
        self._sent_final_frame = False
        self._finalizing_turn = False
        # Recognised segments of the current utterance, keyed by the result's
        # sequence number (``sn``) so replacement ranges can address them.
        self._partials: dict[int, str] = {}
        self._last_text = ""

    async def cleanup(self) -> None:
        """Close any active utterance, then run the base-class cleanup."""
        await self._close_utterance()
        await super().cleanup()

    async def stop(self, frame: EndFrame) -> None:
        """Close any active utterance, then run the base-class stop."""
        await self._close_utterance()
        await super().stop(frame)

    async def cancel(self, frame: CancelFrame) -> None:
        """Close any active utterance, then run the base-class cancel."""
        await self._close_utterance()
        await super().cancel(frame)

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Let the base class handle the frame, then react to turn boundaries."""
        await super().process_frame(frame, direction)

        if isinstance(frame, UserStoppedSpeakingFrame):
            # Aggregator-level turn end (broadcast once per logical user turn).
            # This is the only boundary that finalizes/closes the xfyun
            # websocket, so brief VAD pauses do not restart the ASR session.
            await self._finish_utterance()
        elif isinstance(frame, VADUserStartedSpeakingFrame):
            await self._start_utterance()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
        """Buffer ``audio`` and stream every complete chunk to the server.

        Always yields ``None``: transcription frames are pushed from the
        receive task rather than returned from this generator.
        """
        if not audio:
            yield None
            return

        if not self._is_open():
            await self._start_utterance()

        # Once the end-of-audio message has gone out, this session must not
        # receive more audio; drop it until the next session is opened.
        if not self._sent_final_frame:
            self._audio_buffer.extend(audio)
            await self._flush_audio_buffer(final=False)
        yield None

    def _is_open(self) -> bool:
        """Return True when a websocket exists and is in the OPEN state."""
        return self._websocket is not None and self._websocket.state is State.OPEN

    async def _start_utterance(self) -> None:
        """Open a websocket session and start its receive task.

        Does nothing when a session is already open. Pushes an error frame
        (and opens nothing) when credentials are missing, the sample rate is
        unsupported, or the connection attempt fails.
        """
        if self._is_open():
            return

        if not (self._app_id and self._api_key and self._api_secret):
            await self.push_error("Xfyun ASR requires app_id, api_key, and api_secret")
            return

        if self.sample_rate not in (8000, 16000):
            await self.push_error("Xfyun ASR sample rate must be 8000 or 16000")
            return

        self._audio_buffer.clear()
        self._partials = {}
        self._last_text = ""
        self._sent_first_frame = False
        self._sent_final_frame = False
        # A previous session that dropped without going through
        # _close_utterance() may have left this set; a stale value would
        # suppress every interim frame of the new utterance.
        self._finalizing_turn = False

        auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
        try:
            self._websocket = await websocket_connect(
                auth_url,
                max_size=None,
                open_timeout=self._open_timeout,
            )
        except Exception as exc:
            await self.push_error(f"Xfyun ASR connection failed: {exc}", exception=exc)
            self._websocket = None
            return

        self._receive_task = self.create_task(
            self._receive_messages(),
            name="xfyun_asr_receive",
        )

    async def _finish_utterance(self) -> None:
        """Flush buffered audio and send the end-of-audio message (status 2).

        If no audio was ever sent on this session it is simply closed. The
        session otherwise stays open until the final response arrives.
        """
        if not self._is_open() or self._sent_final_frame:
            return

        await self._flush_audio_buffer(final=True)
        if not self._sent_first_frame:
            await self._close_utterance()
            return

        # Flag first so no further audio is queued behind the final message.
        self._finalizing_turn = True
        self._sent_final_frame = True
        await self._send_payload({"data": {"status": 2}})
        self.request_finalize()

    async def _close_utterance(self) -> None:
        """Tear down the current session and reset per-utterance state.

        Safe to call from the receive task itself (it will not cancel the
        task it is running in) and when no session exists.
        """
        # Detach the connection *before* cancelling the receive task: that
        # task's ``finally`` clears ``self._websocket`` when it still owns
        # it, which would otherwise leave the socket unclosed.
        websocket, self._websocket = self._websocket, None
        receive_task = self._receive_task

        # Reset state before the first await so a session started while we
        # are closing is not clobbered afterwards.
        self._audio_buffer.clear()
        self._sent_first_frame = False
        self._sent_final_frame = False
        self._finalizing_turn = False

        if receive_task is not None and receive_task is not asyncio.current_task():
            self._receive_task = None
            await self.cancel_task(receive_task)

        await self._close_websocket(websocket)

    async def _close_websocket(self, websocket: Any) -> None:
        """Best-effort close of ``websocket``; failures are only logged."""
        if websocket is None or websocket.state is not State.OPEN:
            return
        try:
            await websocket.close()
        except Exception as exc:
            logger.debug(f"Ignoring error while closing Xfyun ASR websocket: {exc}")

    async def _flush_audio_buffer(self, *, final: bool) -> None:
        """Send every full ``frame_size`` chunk; with ``final``, the tail too."""
        while len(self._audio_buffer) >= self._frame_size:
            chunk = bytes(self._audio_buffer[: self._frame_size])
            del self._audio_buffer[: self._frame_size]
            await self._send_audio_chunk(chunk, status=1)

        if final and self._audio_buffer:
            chunk = bytes(self._audio_buffer)
            self._audio_buffer.clear()
            await self._send_audio_chunk(chunk, status=1)

    async def _send_audio_chunk(self, audio: bytes, *, status: int) -> None:
        """Send one base64-encoded audio message.

        The first message of a session always carries ``status`` 0 together
        with the ``common`` and ``business`` sections; later messages carry
        only ``data`` with the given ``status``.
        """
        # Without an open socket nothing is sent, so the first-frame flag
        # must not be consumed either.
        if not audio or not self._is_open():
            return

        data: dict[str, Any] = {
            "status": status,
            "format": f"audio/L16;rate={self.sample_rate}",
            "encoding": self._encoding,
            "audio": base64.b64encode(audio).decode("utf-8"),
        }
        payload: dict[str, Any] = {"data": data}

        if not self._sent_first_frame:
            data["status"] = 0
            business = {
                "language": self._language,
                "domain": self._domain,
                "accent": self._accent,
            }
            if self._dynamic_correction:
                business["dwa"] = "wpgs"
            payload = {
                "common": {"app_id": self._app_id},
                "business": business,
                "data": data,
            }
            self._sent_first_frame = True

        await self._send_payload(payload)

    async def _send_payload(self, payload: dict[str, Any]) -> None:
        """Serialize ``payload`` to JSON and send it if the socket is open."""
        if not self._is_open():
            return
        await self._websocket.send(json.dumps(payload, ensure_ascii=False))

    async def _receive_messages(self) -> None:
        """Receive loop: decode each server message and process it.

        Runs until the connection ends or processing raises. Errors are
        reported via ``push_error`` only while this session is still the
        active one.
        """
        websocket = self._websocket
        if not websocket:
            return

        try:
            async for message in websocket:
                await self._process_response(json.loads(message))
        except Exception as exc:
            if self._websocket is websocket:
                await self.push_error(f"Xfyun ASR receive failed: {exc}", exception=exc)
        finally:
            if self._websocket is websocket:
                self._websocket = None
                # The loop may have stopped on a processing error with the
                # socket still open; close it instead of dropping it.
                await self._close_websocket(websocket)
            # Only clear the handle if it still refers to this task; a newer
            # session may already have installed its own receive task.
            if self._receive_task is asyncio.current_task():
                self._receive_task = None

    async def _process_response(self, payload: dict[str, Any]) -> None:
        """Turn one decoded server message into transcription frames.

        A non-zero ``code`` is reported as an error. Otherwise the result is
        merged into the running text; an interim frame is pushed when the
        text changed and the turn is not being finalized, and on the final
        response (``data.status == 2``) a ``TranscriptionFrame`` is pushed
        (if any text was recognised) and the session is closed.
        """
        code = payload.get("code", -1)
        if code != 0:
            message = payload.get("message", "unknown error")
            sid = payload.get("sid")
            await self.push_error(f"Xfyun ASR error code={code}, sid={sid}, message={message}")
            return

        data = payload.get("data")
        if not isinstance(data, dict):
            return

        is_final_response = data.get("status") == 2
        recognition = data.get("result")
        if isinstance(recognition, dict):
            text = self._apply_recognition_result(recognition)
            if text and text != self._last_text:
                self._last_text = text
                if not self._finalizing_turn and not is_final_response:
                    await self.push_frame(
                        InterimTranscriptionFrame(
                            text,
                            self._user_id,
                            time_now_iso8601(),
                            _language_or_none(self._language),
                            result=payload,
                        )
                    )

        if is_final_response:
            final_text = self._last_text
            if final_text:
                self.confirm_finalize()
                await self.push_frame(
                    TranscriptionFrame(
                        final_text,
                        self._user_id,
                        time_now_iso8601(),
                        _language_or_none(self._language),
                        result=payload,
                    )
                )
            await self._close_utterance()

    def _apply_recognition_result(self, recognition: dict[str, Any]) -> str:
        """Merge one recognition result and return the full utterance text.

        Segments are stored under their sequence number (``sn``). With
        dynamic correction enabled, a ``pgs == "rpl"`` result first removes
        every stored segment whose ``sn`` lies in the inclusive range
        ``rg = [start, end]``. Results without text leave the current text
        unchanged.
        """
        partial = _extract_text_from_result(recognition)
        if not partial:
            return self._last_text

        sn = recognition.get("sn")
        if not isinstance(sn, int):
            # No usable sequence number: treat the result as the next segment.
            sn = max(self._partials, default=0) + 1

        if self._dynamic_correction and recognition.get("pgs") == "rpl" and recognition.get("rg"):
            rg = recognition["rg"]
            if (
                isinstance(rg, (list, tuple))
                and len(rg) == 2
                and all(isinstance(bound, int) for bound in rg)
            ):
                start, end = rg
                for stale in [key for key in self._partials if start <= key <= end]:
                    del self._partials[stale]
            else:
                logger.debug(f"Ignoring malformed Xfyun replacement rg={rg}")

        self._partials[sn] = partial
        return "".join(text for _, text in sorted(self._partials.items()))
|
|
|
|
|
|
def _extract_text_from_result(result: dict[str, Any]) -> str:
    """Concatenate every non-empty ``w`` value found under ``ws[*].cw[*]``.

    Missing ``ws`` / ``cw`` keys are treated as empty lists, and candidates
    without a truthy ``w`` are skipped, so an empty result yields ``""``.
    """
    candidate_words = (
        entry.get("w")
        for segment in result.get("ws", [])
        for entry in segment.get("cw", [])
    )
    # filter(None, ...) drops absent and empty words before joining.
    return "".join(filter(None, candidate_words))
|
|
|
|
|
|
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
    """Return ``url`` with the signed ``authorization``/``date``/``host`` query.

    The signature is an HMAC-SHA256 (keyed with ``api_secret``) over the
    ``host``, the current UTC date in RFC 1123 format, and the request line
    ``GET <path> HTTP/1.1``. The base64 signature is embedded in an
    authorization string that is itself base64-encoded.

    Args:
        url: WebSocket endpoint. When it has no path, ``/v2/iat`` is used
            both for signing and in the returned URL. An existing query
            string is preserved and the auth parameters are appended to it.
        api_key: Key placed in the authorization string.
        api_secret: Secret used as the HMAC key.

    Returns:
        The full URL to connect to.
    """
    parsed = urlparse(url)
    host = parsed.netloc
    path = parsed.path or "/v2/iat"
    date = format_datetime(datetime.now(timezone.utc), usegmt=True)

    request_line = f"GET {path} HTTP/1.1"
    signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")

    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

    auth_query = urlencode({"authorization": authorization, "date": date, "host": host})
    # Append to an existing query instead of adding a second "?".
    query = f"{parsed.query}&{auth_query}" if parsed.query else auth_query
    # Rebuild from components so the requested path is the signed path.
    return parsed._replace(path=path, query=query).geturl()
|
|
|
|
|
|
def _language_or_none(value: str) -> Language | None:
    """Map a language code to a ``Language`` member, or ``None`` if unknown.

    ``value`` is tried verbatim first. If that fails and it is a string, a
    normalised form is tried: underscores become hyphens, the primary subtag
    is lower-cased and the remainder upper-cased (``"zh_cn"`` -> ``"zh-CN"``).
    This lets Xfyun-style codes resolve instead of always yielding ``None``.
    NOTE(review): assumes ``Language`` values use the ``xx-YY`` form; confirm
    against the installed pipecat version.
    """
    try:
        return Language(value)
    except ValueError:
        pass

    if not isinstance(value, str):
        return None

    primary, separator, region = value.replace("_", "-").partition("-")
    normalized = f"{primary.lower()}-{region.upper()}" if separator else primary.lower()
    if normalized == value:
        return None

    try:
        return Language(normalized)
    except ValueError:
        return None
|