# File: engine-v5-pipecat-core/engine/xfyun_asr.py
# Snapshot: 2026-05-22 08:37:31 +08:00 (354 lines, 12 KiB, Python)
from __future__ import annotations
import asyncio
import base64
import hashlib
import hmac
import json
import os
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from email.utils import format_datetime
from typing import Any
from urllib.parse import urlencode, urlparse
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
TranscriptionFrame,
UserStoppedSpeakingFrame,
VADUserStartedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import STTSettings
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
# Websocket endpoint used by XfyunASRService when no ``url`` argument is given.
DEFAULT_XFYUN_ASR_URL = "wss://iat-api.xfyun.cn/v2/iat"
class XfyunASRService(STTService):
    """iFlytek/Xfyun streaming voice dictation service for Pipecat.

    Audio passed to :meth:`run_stt` is re-chunked into ``frame_size``-byte
    frames and streamed over a websocket to the Xfyun dictation endpoint.
    One websocket session covers one logical user turn:

    * it is opened on ``VADUserStartedSpeakingFrame``, or lazily when audio
      arrives while no session is open;
    * on ``UserStoppedSpeakingFrame`` the remaining audio is flushed and the
      closing ``{"data": {"status": 2}}`` frame is sent;
    * when the server's final response (``data.status == 2``) arrives, the
      accumulated text is pushed as a ``TranscriptionFrame`` and the session
      is closed.

    Audio that arrives after the closing frame was sent is held back and
    streamed to the next session instead of the one being finalized.
    """

    def __init__(
        self,
        *,
        app_id: str = "",
        api_key: str = "",
        api_secret: str = "",
        url: str | None = None,
        language: str = "zh_cn",
        domain: str = "iat",
        accent: str = "mandarin",
        sample_rate: int = 16000,
        encoding: str = "raw",
        frame_size: int = 1280,
        open_timeout: float = 10.0,
        dynamic_correction: bool = False,
        **kwargs,
    ) -> None:
        """Configure the service; no connection is opened here.

        Args:
            app_id: Xfyun application id; falls back to ``XFYUN_APP_ID`` when empty.
            api_key: Xfyun API key; falls back to ``XFYUN_API_KEY`` when empty.
            api_secret: Xfyun API secret; falls back to ``XFYUN_API_SECRET`` when empty.
            url: Websocket endpoint; defaults to ``DEFAULT_XFYUN_ASR_URL``.
            language: Sent as ``business.language`` (e.g. ``"zh_cn"``).
            domain: Sent as ``business.domain``.
            accent: Sent as ``business.accent``.
            sample_rate: Input sample rate; sessions are only opened at 8000 or 16000.
            encoding: Sent as ``data.encoding`` with every audio frame.
            frame_size: Number of audio bytes per websocket frame; must be positive.
            open_timeout: Seconds to wait for the websocket handshake.
            dynamic_correction: Request incremental corrections (``dwa=wpgs``).
            **kwargs: Forwarded to ``STTService``.

        Raises:
            ValueError: If ``frame_size`` is not positive.
        """
        if frame_size <= 0:
            # A non-positive size would make _flush_audio_buffer loop forever.
            raise ValueError(f"frame_size must be positive, got {frame_size}")
        super().__init__(
            sample_rate=sample_rate,
            settings=STTSettings(model=None, language=language),
            **kwargs,
        )
        self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
        self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
        self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
        self._url = url or DEFAULT_XFYUN_ASR_URL
        self._language = language
        self._domain = domain
        self._accent = accent
        self._encoding = encoding
        self._frame_size = frame_size
        self._open_timeout = open_timeout
        self._dynamic_correction = dynamic_correction
        self._websocket = None
        self._receive_task = None
        self._audio_buffer = bytearray()
        # True once the first (status 0) frame was sent on the current session.
        self._sent_first_frame = False
        # True once the closing (status 2) frame was sent on the current session.
        self._sent_final_frame = False
        # Suppresses interim frames while waiting for the final result.
        self._finalizing_turn = False
        # Recognized text segments keyed by the result sequence number ``sn``.
        self._partials: dict[int, str] = {}
        self._last_text = ""

    async def cleanup(self) -> None:
        """Close any open session, then run the base-class cleanup."""
        await self._close_utterance()
        await super().cleanup()

    async def stop(self, frame: EndFrame) -> None:
        """Close any open session before the base class handles ``EndFrame``."""
        await self._close_utterance()
        await super().stop(frame)

    async def cancel(self, frame: CancelFrame) -> None:
        """Close any open session before the base class handles ``CancelFrame``."""
        await self._close_utterance()
        await super().cancel(frame)

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Let the base class handle ``frame``, then react to turn boundaries."""
        await super().process_frame(frame, direction)
        if isinstance(frame, UserStoppedSpeakingFrame):
            # Aggregator-level turn end (broadcast once per logical user turn).
            # This is the only frame that makes us finalize the Xfyun session,
            # so brief VAD pauses do not end or restart it.
            await self._finish_utterance()
        elif isinstance(frame, VADUserStartedSpeakingFrame):
            await self._start_utterance()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
        """Buffer ``audio`` and stream whole frames to the current session.

        Transcripts arrive asynchronously through the receive task, so this
        generator only ever yields ``None``.
        """
        if not audio:
            yield None
            return
        if not self._is_connected():
            await self._start_utterance()
        self._audio_buffer.extend(audio)
        await self._flush_audio_buffer(final=False)
        yield None

    def _is_connected(self) -> bool:
        """Return True while the current session's websocket is open."""
        return self._websocket is not None and self._websocket.state is State.OPEN

    async def _start_utterance(self) -> None:
        """Open a new websocket session unless one is already open.

        Pushes an error (and opens nothing) when credentials are missing, the
        sample rate is unsupported, or the connection fails.
        """
        if self._is_connected():
            return
        # Reset per-session state. The audio buffer is intentionally kept: it
        # may hold audio that was held back while the previous session was
        # being finalized, and that audio belongs to this new session.
        self._partials = {}
        self._last_text = ""
        self._sent_first_frame = False
        self._sent_final_frame = False
        self._finalizing_turn = False
        if not self._app_id or not self._api_key or not self._api_secret:
            await self.push_error("Xfyun ASR requires app_id, api_key, and api_secret")
            return
        if self.sample_rate not in (8000, 16000):
            await self.push_error("Xfyun ASR sample rate must be 8000 or 16000")
            return
        auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
        try:
            websocket = await websocket_connect(
                auth_url,
                max_size=None,
                open_timeout=self._open_timeout,
            )
        except Exception as exc:
            self._websocket = None
            await self.push_error(f"Xfyun ASR connection failed: {exc}", exception=exc)
            return
        self._websocket = websocket
        # Hand the websocket to the task explicitly so it always reads the
        # session it was created for.
        self._receive_task = self.create_task(
            self._receive_messages(websocket),
            name="xfyun_asr_receive",
        )

    async def _finish_utterance(self) -> None:
        """End the current turn: flush buffered audio and send the closing frame.

        If no audio was ever sent on the session, it is simply closed. The
        transcript is pushed later by ``_process_response`` when the server's
        final response arrives.
        """
        if not self._is_connected() or self._sent_final_frame:
            return
        await self._flush_audio_buffer(final=True)
        if not self._sent_first_frame:
            await self._close_utterance()
            return
        if not self._is_connected():
            # The session ended while flushing; there is nothing to finalize.
            return
        # Mark the session as finalizing before awaiting the send so that no
        # further audio is flushed onto it.
        self._finalizing_turn = True
        self._sent_final_frame = True
        await self._send_payload({"data": {"status": 2}})
        self.request_finalize()

    async def _close_utterance(self, *, keep_audio: bool = False) -> None:
        """Tear down the current session (if any) and reset per-session state.

        Args:
            keep_audio: Keep buffered audio for the next session instead of
                discarding it.
        """
        # Detach and reset everything before the first await. Cancelling the
        # receive task runs its cleanup, which must see that it no longer owns
        # the session. A session started while we await below must not have
        # its state clobbered.
        websocket, self._websocket = self._websocket, None
        receive_task, self._receive_task = self._receive_task, None
        if not keep_audio:
            self._audio_buffer.clear()
        self._sent_first_frame = False
        self._sent_final_frame = False
        self._finalizing_turn = False
        # When called from the receive task itself (final response), it must
        # not cancel itself; it ends once the websocket is closed.
        if receive_task is not None and receive_task is not asyncio.current_task():
            await self.cancel_task(receive_task)
        if websocket is not None:
            await self._close_websocket(websocket)

    @staticmethod
    async def _close_websocket(websocket: Any) -> None:
        """Close ``websocket`` best-effort, logging instead of raising on failure."""
        try:
            await websocket.close()
        except Exception as exc:
            logger.debug(f"Ignoring error while closing Xfyun ASR websocket: {exc}")

    async def _flush_audio_buffer(self, *, final: bool) -> None:
        """Send buffered audio in ``frame_size``-byte chunks.

        Args:
            final: Also send the trailing partial chunk (end of turn).

        Does nothing once the closing (status 2) frame has been sent on the
        current session. Status 2 marks the last frame, so later audio stays
        buffered for the next session.
        """
        if self._sent_final_frame:
            return
        size = self._frame_size
        while len(self._audio_buffer) >= size:
            chunk = bytes(self._audio_buffer[:size])
            del self._audio_buffer[:size]
            await self._send_audio_chunk(chunk, status=1)
        if final and self._audio_buffer:
            chunk = bytes(self._audio_buffer)
            self._audio_buffer.clear()
            await self._send_audio_chunk(chunk, status=1)

    async def _send_audio_chunk(self, audio: bytes, *, status: int) -> None:
        """Send one audio frame.

        The first frame of a session is sent with status 0, together with the
        ``common`` and ``business`` sections.
        """
        if not audio:
            return
        data = {
            "status": status,
            "format": f"audio/L16;rate={self.sample_rate}",
            "encoding": self._encoding,
            "audio": base64.b64encode(audio).decode("utf-8"),
        }
        if self._sent_first_frame:
            payload = {"data": data}
        else:
            business = {
                "language": self._language,
                "domain": self._domain,
                "accent": self._accent,
            }
            if self._dynamic_correction:
                business["dwa"] = "wpgs"
            data["status"] = 0
            payload = {
                "common": {"app_id": self._app_id},
                "business": business,
                "data": data,
            }
            self._sent_first_frame = True
        await self._send_payload(payload)

    async def _send_payload(self, payload: dict[str, Any]) -> None:
        """JSON-encode ``payload`` and send it; silently skipped when not connected."""
        if not self._is_connected():
            return
        await self._websocket.send(json.dumps(payload, ensure_ascii=False))

    async def _receive_messages(self, websocket: Any = None) -> None:
        """Consume server messages for one session until it closes.

        Args:
            websocket: The session to read; defaults to the current one.
        """
        if websocket is None:
            websocket = self._websocket
        if websocket is None:
            return
        try:
            async for message in websocket:
                await self._process_response(json.loads(message))
        except Exception as exc:
            # Only report and clean up if this is still the active session.
            # Errors from a session that was already replaced or closed are
            # expected and ignored.
            if self._websocket is websocket:
                self._websocket = None
                await self.push_error(f"Xfyun ASR receive failed: {exc}", exception=exc)
                await self._close_websocket(websocket)
        finally:
            if self._websocket is websocket:
                self._websocket = None
            # Never drop the handle of a newer session's receive task.
            if self._receive_task is asyncio.current_task():
                self._receive_task = None

    async def _process_response(self, payload: dict[str, Any]) -> None:
        """Turn one decoded server message into transcription frames.

        - A non-zero ``code`` pushes an error.
        - Changed text is pushed as an ``InterimTranscriptionFrame``, unless the
          turn is finalizing.
        - The final response (``data.status == 2``) pushes the accumulated text,
          if any, as a ``TranscriptionFrame`` and closes the session. Held-back
          audio is kept for the next session.
        """
        code = payload.get("code", -1)
        if code != 0:
            message = payload.get("message", "unknown error")
            sid = payload.get("sid")
            await self.push_error(f"Xfyun ASR error code={code}, sid={sid}, message={message}")
            return
        data = payload.get("data")
        if not isinstance(data, dict):
            return
        is_final_response = data.get("status") == 2
        recognition = data.get("result")
        if isinstance(recognition, dict):
            text = self._apply_recognition_result(recognition)
            if text != self._last_text:
                # Track every change, including a correction that empties the
                # transcript, so the final frame reflects the latest text.
                self._last_text = text
                if text and not self._finalizing_turn and not is_final_response:
                    await self.push_frame(
                        InterimTranscriptionFrame(
                            text,
                            self._user_id,
                            time_now_iso8601(),
                            _language_or_none(self._language),
                            result=payload,
                        )
                    )
        if not is_final_response:
            return
        if self._last_text:
            self.confirm_finalize()
            await self.push_frame(
                TranscriptionFrame(
                    self._last_text,
                    self._user_id,
                    time_now_iso8601(),
                    _language_or_none(self._language),
                    result=payload,
                )
            )
        await self._close_utterance(keep_audio=True)

    def _apply_recognition_result(self, recognition: dict[str, Any]) -> str:
        """Merge one ``result`` object into the transcript and return the full text.

        Segments are stored under their ``sn`` (result sequence number). With
        dynamic correction enabled, a ``pgs == "rpl"`` result first drops the
        segments whose ``sn`` lies in the inclusive ``rg`` range it replaces.
        Keying by ``sn`` keeps later ``rg`` ranges valid even after an earlier
        replacement collapsed several segments into one.
        """
        text = _extract_text_from_result(recognition)
        sn = recognition.get("sn")
        if not isinstance(sn, int):
            # No usable sequence number: treat the result as an append.
            sn = max(self._partials, default=0) + 1
        if self._dynamic_correction and recognition.get("pgs") == "rpl":
            span = recognition.get("rg")
            if (
                isinstance(span, (list, tuple))
                and len(span) == 2
                and all(isinstance(bound, int) for bound in span)
            ):
                first, last = span
                self._partials = {
                    key: value
                    for key, value in self._partials.items()
                    if not first <= key <= last
                }
            else:
                logger.debug(f"Ignoring malformed Xfyun replacement rg={span!r}")
        self._partials[sn] = text
        return "".join(self._partials[key] for key in sorted(self._partials))
def _extract_text_from_result(result: dict[str, Any]) -> str:
    """Concatenate every non-empty ``w`` string in an Xfyun ``result`` object.

    ``result["ws"]`` is a list of word slots. Each slot has a ``cw`` list of
    candidate dicts, and the ``w`` value of every candidate is joined in
    order. A missing ``ws`` or ``cw`` key counts as an empty list.
    """
    return "".join(
        word
        for slot in result.get("ws", [])
        for candidate in slot.get("cw", [])
        if (word := candidate.get("w"))
    )
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
    """Return ``url`` with Xfyun's HMAC-SHA256 handshake authentication attached.

    The signature covers the ``host``, the ``date`` and the ``GET`` request
    line, keyed with ``api_secret``. The base64-encoded authorization string,
    ``date`` and ``host`` are then added as query parameters.

    If ``url`` has no path, ``/v2/iat`` is used both in the signed request
    line and in the returned URL, so the signature matches the request that
    is actually made. An existing query string on ``url`` is preserved.
    """
    parsed = urlparse(url)
    host = parsed.netloc
    path = parsed.path or "/v2/iat"
    # HTTP-date (RFC 1123) timestamp in GMT.
    date = format_datetime(datetime.now(timezone.utc), usegmt=True)
    request_line = f"GET {path} HTTP/1.1"
    signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    auth_query = urlencode({"authorization": authorization, "date": date, "host": host})
    query = f"{parsed.query}&{auth_query}" if parsed.query else auth_query
    # Rebuild from components so the path we signed is the path we request.
    return parsed._replace(path=path, query=query).geturl()
def _language_or_none(value: str) -> Language | None:
    """Map a language code to a Pipecat ``Language``, or return ``None`` if unknown.

    Candidates are tried in order:

    1. the value as given;
    2. a hyphenated, region-upper-cased form of an Xfyun-style code
       (``"zh_cn"`` -> ``"zh-CN"``);
    3. the bare language (``"zh"``).

    NOTE(review): this assumes Pipecat ``Language`` values use the hyphenated
    ``xx-YY`` form, so the raw Xfyun code alone would not match. Confirm
    against ``pipecat.transcriptions.language``.
    """
    if not value:
        return None
    base, separator, region = value.replace("_", "-").partition("-")
    candidates = [value]
    if separator and region:
        candidates.append(f"{base.lower()}-{region.upper()}")
    candidates.append(base.lower())
    for candidate in candidates:
        try:
            return Language(candidate)
        except ValueError:
            continue
    return None