from __future__ import annotations import asyncio import base64 import hashlib import hmac import json import os from collections.abc import AsyncGenerator from datetime import datetime, timezone from email.utils import format_datetime from typing import Any from urllib.parse import urlencode, urlparse from loguru import logger from pipecat.frames.frames import ( CancelFrame, EndFrame, Frame, InterimTranscriptionFrame, TranscriptionFrame, UserStoppedSpeakingFrame, VADUserStartedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import STTSettings from pipecat.services.stt_service import STTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 from websockets.asyncio.client import connect as websocket_connect from websockets.protocol import State DEFAULT_XFYUN_ASR_URL = "wss://iat-api.xfyun.cn/v2/iat" class XfyunASRService(STTService): """iFlytek/Xfyun streaming voice dictation service for Pipecat.""" def __init__( self, *, app_id: str, api_key: str, api_secret: str, url: str | None = None, language: str = "zh_cn", domain: str = "iat", accent: str = "mandarin", sample_rate: int = 16000, encoding: str = "raw", frame_size: int = 1280, open_timeout: float = 10.0, dynamic_correction: bool = False, **kwargs, ) -> None: super().__init__( sample_rate=sample_rate, settings=STTSettings(model=None, language=language), **kwargs, ) self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "") self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "") self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "") self._url = url or DEFAULT_XFYUN_ASR_URL self._language = language self._domain = domain self._accent = accent self._encoding = encoding self._frame_size = frame_size self._open_timeout = open_timeout self._dynamic_correction = dynamic_correction self._websocket = None self._receive_task = None self._audio_buffer = bytearray() self._sent_first_frame = False self._sent_final_frame = False self._finalizing_turn = False self._partials: list[str] = [] self._last_text = "" async def cleanup(self) -> None: await self._close_utterance() await super().cleanup() async def stop(self, frame: EndFrame) -> None: await self._close_utterance() await super().stop(frame) async def cancel(self, frame: CancelFrame) -> None: await self._close_utterance() await super().cancel(frame) async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: await super().process_frame(frame, direction) if isinstance(frame, UserStoppedSpeakingFrame): # Aggregator-level turn end (broadcast once per logical user turn). # This is the only boundary that finalizes/closes the xfyun # websocket, so brief VAD pauses do not restart the ASR session. await self._finish_utterance() elif isinstance(frame, VADUserStartedSpeakingFrame): await self._start_utterance() async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]: if not audio: yield None return if not self._websocket or self._websocket.state is not State.OPEN: await self._start_utterance() self._audio_buffer.extend(audio) await self._flush_audio_buffer(final=False) yield None async def _start_utterance(self) -> None: if self._websocket and self._websocket.state is State.OPEN: return if not self._app_id or not self._api_key or not self._api_secret: await self.push_error("Xfyun ASR requires app_id, api_key, and api_secret") return if self.sample_rate not in (8000, 16000): await self.push_error("Xfyun ASR sample rate must be 8000 or 16000") return self._audio_buffer.clear() self._partials = [] self._last_text = "" self._sent_first_frame = False self._sent_final_frame = False auth_url = _build_auth_url(self._url, self._api_key, self._api_secret) try: self._websocket = await websocket_connect( auth_url, max_size=None, open_timeout=self._open_timeout, ) except Exception as exc: await self.push_error(f"Xfyun ASR connection failed: {exc}", exception=exc) self._websocket = None return self._receive_task = self.create_task( self._receive_messages(), name="xfyun_asr_receive", ) async def _finish_utterance(self) -> None: if not self._websocket or self._websocket.state is not State.OPEN: return await self._flush_audio_buffer(final=True) if not self._sent_first_frame: await self._close_utterance() return if not self._sent_final_frame: self._finalizing_turn = True await self._send_payload({"data": {"status": 2}}) self.request_finalize() self._sent_final_frame = True async def _close_utterance(self) -> None: current_task = asyncio.current_task() if self._receive_task and self._receive_task is not current_task: await self.cancel_task(self._receive_task) self._receive_task = None websocket = self._websocket self._websocket = None if websocket and websocket.state is State.OPEN: try: await websocket.close() except Exception: pass self._audio_buffer.clear() self._sent_first_frame = False self._sent_final_frame = False self._finalizing_turn = False async def _flush_audio_buffer(self, *, final: bool) -> None: while len(self._audio_buffer) >= self._frame_size: chunk = bytes(self._audio_buffer[: self._frame_size]) del self._audio_buffer[: self._frame_size] await self._send_audio_chunk(chunk, status=1) if final and self._audio_buffer: chunk = bytes(self._audio_buffer) self._audio_buffer.clear() await self._send_audio_chunk(chunk, status=1) async def _send_audio_chunk(self, audio: bytes, *, status: int) -> None: if not audio: return if not self._sent_first_frame: business = { "language": self._language, "domain": self._domain, "accent": self._accent, } if self._dynamic_correction: business["dwa"] = "wpgs" payload = { "common": {"app_id": self._app_id}, "business": business, "data": { "status": 0, "format": f"audio/L16;rate={self.sample_rate}", "encoding": self._encoding, "audio": base64.b64encode(audio).decode("utf-8"), }, } self._sent_first_frame = True else: payload = { "data": { "status": status, "format": f"audio/L16;rate={self.sample_rate}", "encoding": self._encoding, "audio": base64.b64encode(audio).decode("utf-8"), } } await self._send_payload(payload) async def _send_payload(self, payload: dict[str, Any]) -> None: if not self._websocket or self._websocket.state is not State.OPEN: return await self._websocket.send(json.dumps(payload, ensure_ascii=False)) async def _receive_messages(self) -> None: websocket = self._websocket if not websocket: return try: async for message in websocket: await self._process_response(json.loads(message)) except Exception as exc: if self._websocket is websocket: await self.push_error(f"Xfyun ASR receive failed: {exc}", exception=exc) finally: if self._websocket is websocket: self._websocket = None self._receive_task = None async def _process_response(self, payload: dict[str, Any]) -> None: code = payload.get("code", -1) if code != 0: message = payload.get("message", "unknown error") sid = payload.get("sid") await self.push_error(f"Xfyun ASR error code={code}, sid={sid}, message={message}") return data = payload.get("data") if not isinstance(data, dict): return is_final_response = data.get("status") == 2 recognition = data.get("result") if isinstance(recognition, dict): text = self._apply_recognition_result(recognition) if text and text != self._last_text: self._last_text = text if not self._finalizing_turn and not is_final_response: await self.push_frame( InterimTranscriptionFrame( text, self._user_id, time_now_iso8601(), _language_or_none(self._language), result=payload, ) ) if is_final_response: final_text = self._last_text if final_text: self.confirm_finalize() await self.push_frame( TranscriptionFrame( final_text, self._user_id, time_now_iso8601(), _language_or_none(self._language), result=payload, ) ) await self._close_utterance() def _apply_recognition_result(self, recognition: dict[str, Any]) -> str: partial = _extract_text_from_result(recognition) if not partial: return self._last_text if self._dynamic_correction and recognition.get("pgs") == "rpl" and recognition.get("rg"): start, end = recognition["rg"] if 1 <= start <= len(self._partials): self._partials[start - 1 : end] = [partial] else: logger.debug(f"Ignoring out-of-range Xfyun replacement rg={recognition['rg']}") else: self._partials.append(partial) return "".join(self._partials) def _extract_text_from_result(result: dict[str, Any]) -> str: words: list[str] = [] for item in result.get("ws", []): for candidate in item.get("cw", []): word = candidate.get("w") if word: words.append(word) return "".join(words) def _build_auth_url(url: str, api_key: str, api_secret: str) -> str: parsed = urlparse(url) host = parsed.netloc path = parsed.path or "/v2/iat" date = format_datetime(datetime.now(timezone.utc), usegmt=True) request_line = f"GET {path} HTTP/1.1" signature_origin = f"host: {host}\ndate: {date}\n{request_line}" signature_sha = hmac.new( api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256, ).digest() signature = base64.b64encode(signature_sha).decode("utf-8") authorization_origin = ( f'api_key="{api_key}", algorithm="hmac-sha256", ' f'headers="host date request-line", signature="{signature}"' ) authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8") query = urlencode({"authorization": authorization, "date": date, "host": host}) return f"{url}?{query}" def _language_or_none(value: str) -> Language | None: try: return Language(value) except ValueError: return None