Implement StepFun Realtime service and enhance AssistantConfig
- Add new fields to AssistantConfig for realtime interface configuration, including the interface type, values, and secrets.
- Introduce StepFunRealtimeService to handle speech-to-speech processing over WebSocket, integrating STT, LLM, and TTS functionality.
- Refactor pipeline execution to support a new realtime mode, allowing direct text input processing and immediate responses.
- Update model resource testing to include validation of StepFun Realtime connections.
- Enhance the service factory to create realtime services based on configuration settings.
- Update the README to reflect the new realtime capabilities and usage instructions.
This commit is contained in:
@@ -129,6 +129,6 @@ docker compose --profile remote up -d
|
||||
|
||||
- [ ] 联调 Pipecat 1.3.0 语音链路与各 OpenAI 兼容服务
|
||||
- [ ] 起本地 SenseVoice / CosyVoice 的 OpenAI 兼容服务
|
||||
- [ ] `realtime` 模式(目前只 `pipeline` 级联)
|
||||
- [x] `realtime` 模式(StepFun StepAudio Realtime)
|
||||
- [x] 前端 `DebugVoicePanel` 接 `/ws/voice`(参考 dograh `useWebSocketRTC.tsx`)
|
||||
- [ ] 加 DB 后:助手配置入库(目前随请求内联)
|
||||
|
||||
@@ -15,7 +15,6 @@ VALUES
|
||||
('model_004', 'SiliconFlow-CosyVoice2-0.5B', 'TTS', 'openai-tts',
|
||||
'{"modelId":"FunAudioLLM/CosyVoice2-0.5B","apiUrl":"https://api.siliconflow.cn/v1","voice":"FunAudioLLM/CosyVoice2-0.5B:anna","speed":1.0,"sourceSampleRate":24000}',
|
||||
'{"apiKey":"replace-me"}', TRUE, FALSE),
|
||||
'{"apiKey":"replace-me"}', TRUE, FALSE),
|
||||
('model_005', '讯飞语音识别', 'ASR', 'xfyun-asr',
|
||||
'{"apiUrl":"https://iat-api.xfyun.cn/v2/iat","language":"zh_cn","domain":"iat","accent":"mandarin","dynamicCorrection":false,"frameSize":1280}',
|
||||
'{"appId":"replace-me","apiKey":"replace-me","apiSecret":"replace-me"}', TRUE, TRUE),
|
||||
@@ -36,13 +35,9 @@ VALUES
|
||||
'{"apiKey":"replace-me"}', TRUE, FALSE),
|
||||
('model_011', 'text-embedding-3', 'Embedding', 'openai-embedding',
|
||||
'{"modelId":"text-embedding-3-small","apiUrl":"https://api.openai.com/v1/embeddings"}',
|
||||
'{"apiKey":"replace-me"}', TRUE, FALSE),
|
||||
('model_012', 'StepAudio 2.5 Realtime', 'Realtime', 'stepfun-realtime',
|
||||
'{"modelId":"stepaudio-2.5-realtime","apiUrl":"wss://api.stepfun.com/v1/realtime","voice":"linjiajiejie","inputSampleRate":24000,"outputSampleRate":24000,"prefixPaddingMs":500,"silenceDurationMs":300,"energyAwakenessThreshold":2500}',
|
||||
'{"apiKey":"replace-me"}', TRUE, FALSE)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
capability = EXCLUDED.capability,
|
||||
interface_type = EXCLUDED.interface_type,
|
||||
values = EXCLUDED.values,
|
||||
secrets = EXCLUDED.secrets,
|
||||
enabled = EXCLUDED.enabled,
|
||||
is_default = EXCLUDED.is_default,
|
||||
updated_at = now();
|
||||
-- Seed defaults must never overwrite resources configured through the UI.
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
@@ -31,6 +31,9 @@ class AssistantConfig(BaseModel):
|
||||
stt_language: str = ""
|
||||
tts_speed: float = 1.0
|
||||
realtimeModel: str = ""
|
||||
realtime_interface_type: str = ""
|
||||
realtime_values: dict = {}
|
||||
realtime_secrets: dict = {}
|
||||
llm_interface_type: str = "openai-llm"
|
||||
stt_interface_type: str = "openai-asr"
|
||||
tts_interface_type: str = "openai-tts"
|
||||
@@ -51,6 +54,8 @@ class AssistantConfig(BaseModel):
|
||||
stt_base_url: str = ""
|
||||
tts_api_key: str = ""
|
||||
tts_base_url: str = ""
|
||||
realtime_api_key: str = ""
|
||||
realtime_base_url: str = ""
|
||||
|
||||
|
||||
class SignalingOffer(BaseModel):
|
||||
|
||||
@@ -59,6 +59,15 @@ def _resource_out(row: ModelResource) -> ModelResourceOut:
|
||||
)
|
||||
|
||||
|
||||
async def _commit_resource(
    session: AsyncSession, row: ModelResource
) -> ModelResourceOut:
    """Persist the pending changes and return the serialized, reloaded row.

    The row is refreshed after the commit so that server-generated fields are
    loaded before it is converted to its response model.
    """
    await session.commit()
    # Reload from the database before serializing.
    await session.refresh(row)
    serialized = _resource_out(row)
    return serialized
|
||||
|
||||
|
||||
async def _definition(
|
||||
session: AsyncSession, interface_type: str
|
||||
) -> InterfaceDefinition:
|
||||
@@ -141,8 +150,7 @@ async def create_model_resource(
|
||||
.where(ModelResource.capability == row.capability, ModelResource.id != row.id)
|
||||
.values(is_default=False)
|
||||
)
|
||||
await session.commit()
|
||||
return _resource_out(row)
|
||||
return await _commit_resource(session, row)
|
||||
|
||||
|
||||
@router.post("/model-resources/test", response_model=ModelResourceTestResult)
|
||||
@@ -196,8 +204,7 @@ async def duplicate_model_resource(
|
||||
is_default=False,
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
return _resource_out(row)
|
||||
return await _commit_resource(session, row)
|
||||
|
||||
|
||||
@router.put("/model-resources/{resource_id}", response_model=ModelResourceOut)
|
||||
@@ -224,8 +231,7 @@ async def update_model_resource(
|
||||
.where(ModelResource.capability == row.capability, ModelResource.id != row.id)
|
||||
.values(is_default=False)
|
||||
)
|
||||
await session.commit()
|
||||
return _resource_out(row)
|
||||
return await _commit_resource(session, row)
|
||||
|
||||
|
||||
@router.delete("/model-resources/{resource_id}")
|
||||
|
||||
@@ -58,7 +58,9 @@ async def voice_signaling(websocket: WebSocket):
|
||||
except Exception as e:
|
||||
logger.error(f"WebRTC 信令出错: {e}")
|
||||
finally:
|
||||
for pc in peers.values():
|
||||
# disconnect() triggers the registered closed callback, which removes
|
||||
# the peer from this dict. Iterate over a snapshot to avoid mutation.
|
||||
for pc in list(peers.values()):
|
||||
await pc.disconnect()
|
||||
|
||||
|
||||
|
||||
@@ -83,6 +83,11 @@ async def resolve_runtime_config(
|
||||
stt_secrets=(stt_resource.secrets or {}) if stt_resource else {},
|
||||
tts_values=(tts_resource.values or {}) if tts_resource else {},
|
||||
tts_secrets=(tts_resource.secrets or {}) if tts_resource else {},
|
||||
realtime_interface_type=(
|
||||
realtime_resource.interface_type if realtime_resource else ""
|
||||
),
|
||||
realtime_values=(realtime_resource.values or {}) if realtime_resource else {},
|
||||
realtime_secrets=(realtime_resource.secrets or {}) if realtime_resource else {},
|
||||
# 运行时连接信息(真 key + url):模型资源优先,否则 .env 兜底
|
||||
llm_api_key=_secret(llm_resource, "apiKey", config.LLM_API_KEY),
|
||||
llm_base_url=str(_value(llm_resource, "apiUrl", config.LLM_BASE_URL)),
|
||||
@@ -90,4 +95,6 @@ async def resolve_runtime_config(
|
||||
stt_base_url=str(_value(stt_resource, "apiUrl", config.STT_BASE_URL)),
|
||||
tts_api_key=_secret(tts_resource, "apiKey", config.TTS_API_KEY),
|
||||
tts_base_url=str(_value(tts_resource, "apiUrl", config.TTS_BASE_URL)),
|
||||
realtime_api_key=_secret(realtime_resource, "apiKey", ""),
|
||||
realtime_base_url=str(_value(realtime_resource, "apiUrl", "")),
|
||||
)
|
||||
|
||||
@@ -78,6 +78,25 @@ INTERFACE_DEFINITIONS: list[dict] = [
|
||||
"capability": "Realtime",
|
||||
"fields": OPENAI_COMMON + [field("voice", "Voice")],
|
||||
},
|
||||
{
|
||||
"interface_type": "stepfun-realtime",
|
||||
"name": "StepFun StepAudio Realtime",
|
||||
"capability": "Realtime",
|
||||
"fields": OPENAI_COMMON
|
||||
+ [
|
||||
field("voice", "Voice", required=True, default="linjiajiejie"),
|
||||
field("inputSampleRate", "Input Sample Rate", type_="number", default=24000),
|
||||
field("outputSampleRate", "Output Sample Rate", type_="number", default=24000),
|
||||
field("prefixPaddingMs", "VAD Prefix Padding (ms)", type_="number", default=500),
|
||||
field("silenceDurationMs", "VAD Silence Duration (ms)", type_="number", default=300),
|
||||
field(
|
||||
"energyAwakenessThreshold",
|
||||
"VAD Energy Threshold",
|
||||
type_="number",
|
||||
default=2500,
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"interface_type": "xfyun-asr",
|
||||
"name": "Xfyun Streaming ASR",
|
||||
|
||||
@@ -2,11 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
import wave
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
import httpx
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
|
||||
import config
|
||||
from schemas import ModelResourceTestResult
|
||||
@@ -57,6 +61,8 @@ async def test_model_resource(
|
||||
message="讯飞连接参数有效",
|
||||
detail="鉴权字段和连接参数完整,请在语音测试页验证签名及音频链路",
|
||||
)
|
||||
if interface_type == "stepfun-realtime":
|
||||
return await _test_stepfun_realtime(values, secrets)
|
||||
if capability == "Realtime":
|
||||
return ModelResourceTestResult(
|
||||
ok=False,
|
||||
@@ -150,3 +156,61 @@ async def test_model_resource(
|
||||
message="无法连接到模型服务",
|
||||
detail=str(exc)[:300],
|
||||
)
|
||||
|
||||
|
||||
async def _test_stepfun_realtime(
    values: dict, secrets: dict
) -> ModelResourceTestResult:
    """Probe a StepFun Realtime endpoint by opening its WebSocket once.

    The connection URL is ``values["apiUrl"]`` with the ``model`` query
    parameter set to ``values["modelId"]``; ``secrets["apiKey"]`` is sent as a
    bearer token. The check passes when the first message received is a
    ``session.created`` event.

    Args:
        values: Non-secret resource settings (reads ``apiUrl`` and ``modelId``).
        secrets: Secret resource settings (reads ``apiKey``); every non-empty
            secret value is masked in the returned error detail.

    Returns:
        A ``ModelResourceTestResult`` describing success, an unexpected first
        event, a timeout, or a connection failure. This function does not
        raise for connection problems.
    """
    api_url = str(values.get("apiUrl") or "")
    model_id = str(values.get("modelId") or "")
    api_key = str(secrets.get("apiKey") or "")

    # Fail early with an explicit message instead of an opaque URI/auth error.
    missing = [
        name
        for name, value in (
            ("apiUrl", api_url),
            ("modelId", model_id),
            ("apiKey", api_key),
        )
        if not value
    ]
    if missing:
        return ModelResourceTestResult(
            ok=False,
            latency_ms=0,
            message="StepFun Realtime 配置不完整",
            detail=f"缺少必填字段: {', '.join(missing)}"[:300],
        )

    # Keep any query parameters already present on the URL and set ``model``.
    parts = urlsplit(api_url)
    query = dict(parse_qsl(parts.query))
    query["model"] = model_id
    url = urlunsplit(
        (parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)
    )
    started = time.perf_counter()

    def _result(ok: bool, message: str, detail: str) -> ModelResourceTestResult:
        # Latency is measured from just before the connection attempt.
        return ModelResourceTestResult(
            ok=ok,
            latency_ms=round((time.perf_counter() - started) * 1000),
            message=message,
            detail=detail[:300],
        )

    try:
        async with websocket_connect(
            url,
            additional_headers={"Authorization": f"Bearer {api_key}"},
            open_timeout=TEST_TIMEOUT_SECONDS,
            close_timeout=2,
        ) as websocket:
            raw_message = await asyncio.wait_for(
                websocket.recv(), timeout=TEST_TIMEOUT_SECONDS
            )
        event = json.loads(raw_message)
        # A valid JSON reply is not necessarily an object; do not call .get on it.
        event_type = event.get("type") if isinstance(event, dict) else None
        if event_type != "session.created":
            return _result(
                False, "Realtime 连接返回了意外事件", str(event_type or event)
            )
        return _result(True, "Realtime 连接成功", "StepFun 返回 session.created")
    except (TimeoutError, asyncio.TimeoutError):
        # Before Python 3.11 ``asyncio.wait_for`` raises ``asyncio.TimeoutError``,
        # which is not the builtin ``TimeoutError``; catch both.
        return _result(
            False,
            "Realtime 连接超时",
            f"服务未在 {TEST_TIMEOUT_SECONDS:g} 秒内创建 session",
        )
    except Exception as exc:
        # Some exceptions stringify to "", so fall back to the class name.
        detail = str(exc) or type(exc).__name__
        # Never echo credentials back to the caller.
        for secret in secrets.values():
            if secret:
                detail = detail.replace(str(secret), "***")
        return _result(False, "无法连接到 StepFun Realtime", detail)
|
||||
|
||||
@@ -10,7 +10,7 @@ from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
from models import AssistantConfig
|
||||
from services.pipecat.service_factory import create_services
|
||||
from services.pipecat.service_factory import create_realtime_service, create_services
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
@@ -106,6 +106,33 @@ class TextInputProcessor(FrameProcessor):
|
||||
await self._call_event_handler("on_text_append", text)
|
||||
|
||||
|
||||
class RealtimeTextInputProcessor(FrameProcessor):
    """Turn client text messages into events for a realtime service.

    Frames that are not recognized text input are forwarded unchanged.
    Recognized text input is consumed and reported through the
    ``on_text_input`` or ``on_text_append`` event handler instead.
    """

    def __init__(self):
        super().__init__()
        for event_name in ("on_text_input", "on_text_append"):
            self._register_event_handler(event_name)

    async def process_frame(self, frame, direction: FrameDirection):
        """Forward ordinary frames; dispatch text input to the event handlers."""
        await super().process_frame(frame, direction)

        # Only transport messages can carry text input. ``_text_input`` is
        # expected to return a falsy value or a ``(text, run_immediately)``
        # pair, judging by how its result is used here.
        parsed = (
            _text_input(frame.message)
            if isinstance(frame, InputTransportMessageFrame)
            else None
        )
        if not parsed:
            await self.push_frame(frame, direction)
            return

        text, run_immediately = parsed
        if run_immediately:
            event_name = "on_text_input"
        else:
            event_name = "on_text_append"
        await self._call_event_handler(event_name, text)
|
||||
|
||||
|
||||
class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
|
||||
"""聚合 LLM 回复进上下文,同时继续把回复帧交给下游 TTS。"""
|
||||
|
||||
@@ -176,6 +203,10 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
"""
|
||||
logger.info(f"启动管线: assistant={cfg.name} mode={cfg.runtimeMode}")
|
||||
|
||||
if cfg.runtimeMode == "realtime":
|
||||
await run_realtime_pipeline(transport, cfg)
|
||||
return
|
||||
|
||||
stt, llm, tts = create_services(cfg)
|
||||
|
||||
context = LLMContext(messages=[{"role": "system", "content": cfg.prompt}])
|
||||
@@ -327,3 +358,70 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
await runner.add_workers(worker)
|
||||
await runner.run()
|
||||
logger.info("管线已结束")
|
||||
|
||||
|
||||
async def run_realtime_pipeline(transport, cfg: AssistantConfig) -> None:
    """Run the realtime (speech-to-speech) pipeline until it ends.

    The pipeline is ``transport.input() -> text input -> realtime service ->
    transport.output()``. Typed user text is echoed to the client as a
    transcript message and handed to the realtime service; the configured
    greeting, if any, is spoken when a client connects; a client disconnect
    queues an ``EndFrame``.

    Args:
        transport: Transport providing ``input()``, ``output()`` and the
            ``on_client_connected`` / ``on_client_disconnected`` events.
        cfg: Assistant configuration; ``realtime_values`` supplies the
            ``inputSampleRate`` / ``outputSampleRate`` (default 24000).
    """
    realtime = create_realtime_service(cfg)
    text_input = RealtimeTextInputProcessor()

    stages = [
        transport.input(),
        text_input,
        realtime,
        transport.output(),
    ]
    pipeline = Pipeline(stages)

    # A missing or falsy sample-rate entry falls back to 24000 Hz.
    rate_in = int(cfg.realtime_values.get("inputSampleRate") or 24000)
    rate_out = int(cfg.realtime_values.get("outputSampleRate") or 24000)
    params = PipelineParams(
        enable_metrics=False,
        audio_in_sample_rate=rate_in,
        audio_out_sample_rate=rate_out,
    )
    worker = PipelineWorker(pipeline, params=params, enable_rtvi=False)

    async def queue_transcript(role: str, content: str) -> None:
        # Empty text produces no transcript message.
        if not content:
            return
        message = {
            "type": "transcript",
            "role": role,
            "content": content,
            "timestamp": time_now_iso8601(),
        }
        await worker.queue_frame(OutputTransportMessageUrgentFrame(message=message))

    @text_input.event_handler("on_text_input")
    async def on_text_input(_processor, text):
        # Immediate input: echo it, interrupt the service, then request a reply.
        await queue_transcript("user", text)
        await realtime.interrupt()
        await realtime.send_text(text, run_immediately=True)

    @text_input.event_handler("on_text_append")
    async def on_text_append(_processor, text):
        # Appended input: echo it and add it without requesting a reply.
        await queue_transcript("user", text)
        await realtime.send_text(text, run_immediately=False)

    @transport.event_handler("on_client_connected")
    async def on_client_connected(_transport, _client):
        greeting = cfg.greeting
        if greeting:
            await realtime.speak(greeting)

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(_transport, _client):
        logger.info("Realtime 对端断开,结束管线")
        await worker.queue_frame(EndFrame())

    runner = WorkerRunner(handle_sigint=False)
    await runner.add_workers(worker)
    await runner.run()
    logger.info("Realtime 管线已结束")
|
||||
|
||||
@@ -133,3 +133,33 @@ def create_services(cfg: AssistantConfig):
|
||||
f"voice={cfg.voice or config.TTS_VOICE}"
|
||||
)
|
||||
return create_stt(cfg), create_llm(cfg), create_tts(cfg)
|
||||
|
||||
|
||||
def create_realtime_service(cfg: AssistantConfig):
    """Build the speech-to-speech service named by ``cfg.realtime_interface_type``.

    Only ``"stepfun-realtime"`` is supported. Voice, sample rates and VAD
    settings are read from ``cfg.realtime_values`` with built-in defaults.

    Raises:
        ValueError: If the interface type is not supported.
    """
    if cfg.realtime_interface_type != "stepfun-realtime":
        raise ValueError(f"不支持的 Realtime 接口类型: {cfg.realtime_interface_type}")

    # Imported lazily, only when this interface type is requested.
    from services.pipecat.stepfun_realtime import StepFunRealtimeService

    options = cfg.realtime_values

    def as_int(key: str, fallback: int) -> int:
        # NOTE: a missing or falsy entry (None, "", 0) yields the fallback.
        return int(options.get(key) or fallback)

    return StepFunRealtimeService(
        api_key=cfg.realtime_api_key,
        model=cfg.realtimeModel,
        base_url=cfg.realtime_base_url,
        instructions=cfg.prompt,
        voice=str(options.get("voice") or "linjiajiejie"),
        input_sample_rate=as_int("inputSampleRate", 24000),
        output_sample_rate=as_int("outputSampleRate", 24000),
        prefix_padding_ms=as_int("prefixPaddingMs", 500),
        silence_duration_ms=as_int("silenceDurationMs", 300),
        energy_awakeness_threshold=as_int("energyAwakenessThreshold", 2500),
    )
|
||||
|
||||
367
backend/services/pipecat/stepfun_realtime.py
Normal file
367
backend/services/pipecat/stepfun_realtime.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""StepFun StepAudio realtime speech-to-speech Pipecat service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import ServiceSettings
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
|
||||
DEFAULT_STEPFUN_REALTIME_URL = "wss://api.stepfun.com/v1/realtime"
|
||||
|
||||
|
||||
class StepFunRealtimeService(AIService):
    """Bridge Pipecat audio frames to StepFun's Realtime WebSocket events.

    Data flow implemented by this class:

    * ``InputAudioRawFrame`` audio is base64-encoded and sent as
      ``input_audio_buffer.append`` events.
    * ``response.audio.delta`` events are decoded and pushed downstream as
      ``TTSAudioRawFrame``.
    * Assistant text/transcript events and user transcriptions are relayed to
      the client as ``OutputTransportMessageUrgentFrame`` messages.
    * Events issued before the server acknowledges ``session.update`` are
      buffered (bounded) and flushed when ``session.updated`` arrives.
    """

    # Upper bound for events buffered while the session is not ready. Audio
    # frames arrive continuously, so without a bound the buffer would grow for
    # as long as the session stays unready (for example after the connection
    # has dropped and nothing will flush it any more).
    _MAX_PENDING_EVENTS = 256

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str = DEFAULT_STEPFUN_REALTIME_URL,
        instructions: str = "",
        voice: str = "linjiajiejie",
        input_sample_rate: int = 24000,
        output_sample_rate: int = 24000,
        prefix_padding_ms: int = 500,
        silence_duration_ms: int = 300,
        energy_awakeness_threshold: int = 2500,
        **kwargs,
    ) -> None:
        """Store connection and session settings; no I/O happens here.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model id, sent as the ``model`` query parameter.
            base_url: WebSocket endpoint; an empty value selects the default.
            instructions: Session instructions sent in ``session.update``.
            voice: Voice name sent in ``session.update``.
            input_sample_rate: Expected input rate; a mismatch is only logged.
            output_sample_rate: Rate stamped on emitted ``TTSAudioRawFrame``s.
            prefix_padding_ms: ``turn_detection.prefix_padding_ms``.
            silence_duration_ms: ``turn_detection.silence_duration_ms``.
            energy_awakeness_threshold:
                ``turn_detection.energy_awakeness_threshold``.
            **kwargs: Forwarded to ``AIService``.
        """
        super().__init__(settings=ServiceSettings(model=model), **kwargs)
        self._api_key = api_key
        self._model = model
        self._base_url = base_url or DEFAULT_STEPFUN_REALTIME_URL
        self._instructions = instructions
        self._voice = voice
        self._input_sample_rate = input_sample_rate
        self._output_sample_rate = output_sample_rate
        self._prefix_padding_ms = prefix_padding_ms
        self._silence_duration_ms = silence_duration_ms
        self._energy_awakeness_threshold = energy_awakeness_threshold
        # The sample-rate mismatch warning is emitted at most once.
        self._warned_input_sample_rate = False
        self._websocket = None
        self._receive_task: asyncio.Task | None = None
        # Set once the server confirms our session.update.
        self._session_ready = asyncio.Event()
        self._pending_events: list[dict[str, Any]] = []
        # State of the assistant text turn currently streamed to the client.
        self._assistant_turn_id: str | None = None
        self._assistant_text = ""
        self._assistant_timestamp = ""

    async def start(self, frame: StartFrame) -> None:
        """Validate credentials and open the WebSocket connection."""
        await super().start(frame)
        if not self._api_key or not self._model:
            await self.push_error(
                "StepFun Realtime requires api_key and model", fatal=True
            )
            return
        await self._connect()

    async def stop(self, frame: EndFrame) -> None:
        """Close the connection, then run the base-class stop."""
        await self._disconnect()
        await super().stop(frame)

    async def cancel(self, frame: CancelFrame) -> None:
        """Close the connection, then run the base-class cancel."""
        await self._disconnect()
        await super().cancel(frame)

    async def cleanup(self) -> None:
        """Close the connection, then run the base-class cleanup."""
        await self._disconnect()
        await super().cleanup()

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Consume audio and appended messages; forward every other frame.

        Audio frames and ``LLMMessagesAppendFrame`` are translated into
        server events and not pushed further. ``InterruptionFrame`` cancels
        the active response and is then forwarded like any other frame.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, InputAudioRawFrame):
            if (
                frame.sample_rate != self._input_sample_rate
                and not self._warned_input_sample_rate
            ):
                self._warned_input_sample_rate = True
                logger.warning(
                    "StepFun Realtime expected {} Hz input, received {} Hz",
                    self._input_sample_rate,
                    frame.sample_rate,
                )
            await self._send_event(
                {
                    "type": "input_audio_buffer.append",
                    "audio": base64.b64encode(frame.audio).decode("ascii"),
                }
            )
            return
        if isinstance(frame, LLMMessagesAppendFrame):
            # Add every message first and request a single response afterwards;
            # asking for one response per message would issue overlapping
            # response.create events for a multi-message frame.
            appended = False
            for message in frame.messages:
                text = self._message_text(message)
                if text:
                    await self.send_text(text, run_immediately=False)
                    appended = True
            if appended and frame.run_llm is not False:
                await self._send_event({"type": "response.create"})
            return
        if isinstance(frame, InterruptionFrame):
            await self._cancel_active_response()

        await self.push_frame(frame, direction)

    async def send_text(self, text: str, *, run_immediately: bool = True) -> None:
        """Add ``text`` as a user message and optionally request a response."""
        await self._send_event(
            {
                "type": "conversation.item.create",
                "item": {
                    "type": "message",
                    "role": "user",
                    "content": [{"type": "input_text", "text": text}],
                },
            }
        )
        if run_immediately:
            await self._send_event({"type": "response.create"})

    async def interrupt(self) -> None:
        """Cancel the active response and broadcast an interruption."""
        await self._cancel_active_response()
        await self.broadcast_interruption()

    async def speak(self, text: str) -> None:
        """Ask the realtime model to voice a fixed greeting."""
        if not text:
            return
        # NOTE(review): the per-response instructions are sent under a
        # "session" key; confirm this matches the StepFun API reference.
        await self._send_event(
            {
                "type": "response.create",
                "session": {
                    "instructions": f"请原样无修改地输出下面的话:\n{text}",
                },
            }
        )

    async def _cancel_active_response(self) -> None:
        """Send ``response.cancel`` and close the current assistant text turn."""
        await self._send_event({"type": "response.cancel"}, wait_until_ready=False)
        await self._finish_assistant_text(interrupted=True)

    async def _connect(self) -> None:
        """Open the WebSocket and start the receive task (no-op if open)."""
        if self._websocket and self._websocket.state is State.OPEN:
            return
        try:
            self._websocket = await websocket_connect(
                self._connection_url(),
                additional_headers={"Authorization": f"Bearer {self._api_key}"},
                max_size=None,
                open_timeout=10,
            )
            self._receive_task = self.create_task(
                self._receive_messages(), name="stepfun_realtime_receive"
            )
        except Exception as exc:
            self._websocket = None
            await self.push_error(
                f"StepFun Realtime connection failed: {exc}",
                exception=exc,
                fatal=True,
            )

    async def _disconnect(self) -> None:
        """Stop the receive task, close the socket and drop buffered events."""
        current_task = asyncio.current_task()
        task = self._receive_task
        self._receive_task = None
        # The receive task must not cancel itself.
        if task and task is not current_task:
            await self.cancel_task(task)

        websocket = self._websocket
        self._websocket = None
        self._session_ready.clear()
        # Nothing will flush the buffer after a disconnect; release it.
        self._pending_events = []
        if websocket and websocket.state is State.OPEN:
            try:
                await websocket.close()
            except Exception:
                # Best effort: the connection is being discarded anyway.
                pass

    def _connection_url(self) -> str:
        """Return the base URL with the ``model`` query parameter set."""
        parts = urlsplit(self._base_url)
        query = dict(parse_qsl(parts.query))
        query["model"] = self._model
        return urlunsplit(
            (parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)
        )

    async def _receive_messages(self) -> None:
        """Read server messages until the connection ends."""
        websocket = self._websocket
        if not websocket:
            return
        try:
            async for raw_message in websocket:
                # A single malformed message must not end the whole session.
                try:
                    payload = json.loads(raw_message)
                except ValueError as exc:
                    logger.warning(
                        "StepFun Realtime sent an undecodable message: {}", exc
                    )
                    continue
                if not isinstance(payload, dict):
                    logger.warning(
                        "StepFun Realtime sent a non-object message: {}",
                        type(payload).__name__,
                    )
                    continue
                await self._handle_server_event(payload)
        except Exception as exc:
            # Only report failures of the connection that is still current.
            if self._websocket is websocket:
                await self.push_error(
                    f"StepFun Realtime receive failed: {exc}", exception=exc
                )
        finally:
            if self._websocket is websocket:
                self._websocket = None
                self._session_ready.clear()
            if self._receive_task is asyncio.current_task():
                self._receive_task = None

    async def _handle_server_event(self, event: dict[str, Any]) -> None:
        """Dispatch one decoded server event by its ``type``."""
        event_type = event.get("type")
        if event_type == "session.created":
            await self._send_session_update()
        elif event_type == "session.updated":
            self._session_ready.set()
            # Swap the buffer out before flushing it.
            pending, self._pending_events = self._pending_events, []
            for payload in pending:
                await self._send_event(payload, wait_until_ready=False)
        elif event_type == "response.audio.delta":
            audio = event.get("delta")
            if audio:
                await self.push_frame(
                    TTSAudioRawFrame(
                        base64.b64decode(audio),
                        self._output_sample_rate,
                        1,
                    )
                )
        elif event_type in {"response.audio_transcript.delta", "response.text.delta"}:
            await self._append_assistant_text(str(event.get("delta") or ""))
        elif event_type in {"response.audio_transcript.done", "response.text.done"}:
            transcript = str(event.get("transcript") or event.get("text") or "")
            if transcript:
                if not self._assistant_turn_id:
                    # No deltas were seen: emit the final text as one delta.
                    await self._append_assistant_text(transcript)
                else:
                    # The final text replaces the accumulated deltas.
                    self._assistant_text = transcript
            await self._finish_assistant_text(interrupted=False)
        elif event_type == "conversation.item.input_audio_transcription.completed":
            await self._send_transcript("user", str(event.get("transcript") or ""))
        elif event_type == "input_audio_buffer.speech_started":
            await self._send_event({"type": "response.cancel"}, wait_until_ready=False)
            await self.broadcast_interruption()
        elif event_type == "response.done":
            response = event.get("response")
            interrupted = isinstance(response, dict) and response.get("status") in {
                "cancelled",
                "incomplete",
                "interrupted",
            }
            await self._finish_assistant_text(interrupted=interrupted)
        elif event_type == "error":
            error = event.get("error")
            message = error.get("message") if isinstance(error, dict) else str(error)
            # Errors mentioning "cancel" are not surfaced.
            if "cancel" not in str(message).lower():
                await self.push_error(f"StepFun Realtime error: {message}")

    async def _send_session_update(self) -> None:
        """Send the session configuration; bypasses the ready gate."""
        await self._send_event(
            {
                "type": "session.update",
                "session": {
                    "modalities": ["text", "audio"],
                    "instructions": self._instructions,
                    "voice": self._voice,
                    "input_audio_format": "pcm16",
                    "output_audio_format": "pcm16",
                    "turn_detection": {
                        "type": "server_vad",
                        "prefix_padding_ms": self._prefix_padding_ms,
                        "silence_duration_ms": self._silence_duration_ms,
                        "energy_awakeness_threshold": self._energy_awakeness_threshold,
                    },
                },
            },
            wait_until_ready=False,
        )

    async def _send_event(
        self, payload: dict[str, Any], *, wait_until_ready: bool = True
    ) -> None:
        """Send one client event, or buffer it until the session is ready.

        Args:
            payload: Event body; an ``event_id`` is added unless already set.
            wait_until_ready: When true and the session is not ready, the
                event is buffered instead of sent. When false, the event is
                sent now or dropped if no open connection exists.
        """
        if wait_until_ready and not self._session_ready.is_set():
            self._buffer_pending_event(payload)
            return
        if not self._websocket or self._websocket.state is not State.OPEN:
            return
        payload = {"event_id": uuid4().hex, **payload}
        await self._websocket.send(json.dumps(payload, ensure_ascii=False))

    def _buffer_pending_event(self, payload: dict[str, Any]) -> None:
        """Queue an event for the next flush, keeping the buffer bounded."""
        pending = self._pending_events
        if len(pending) >= self._MAX_PENDING_EVENTS:
            # Prefer to evict the oldest audio chunk: audio is the only
            # high-volume event, and stale audio is the least valuable.
            for index, queued in enumerate(pending):
                if queued.get("type") == "input_audio_buffer.append":
                    del pending[index]
                    break
            else:
                del pending[0]
        pending.append(payload)

    async def _append_assistant_text(self, delta: str) -> None:
        """Stream one assistant text delta, opening a new turn if needed."""
        if not delta:
            return
        if not self._assistant_turn_id:
            self._assistant_turn_id = uuid4().hex
            self._assistant_timestamp = time_now_iso8601()
            await self._send_transport_message(
                {
                    "type": "assistant-text-start",
                    "turn_id": self._assistant_turn_id,
                    "timestamp": self._assistant_timestamp,
                }
            )
        self._assistant_text += delta
        await self._send_transport_message(
            {
                "type": "assistant-text-delta",
                "turn_id": self._assistant_turn_id,
                "delta": delta,
            }
        )

    async def _finish_assistant_text(self, *, interrupted: bool) -> None:
        """Close the open assistant text turn; no-op when none is open."""
        if not self._assistant_turn_id:
            return
        await self._send_transport_message(
            {
                "type": "assistant-text-end",
                "turn_id": self._assistant_turn_id,
                "content": self._assistant_text,
                "interrupted": interrupted,
            }
        )
        self._assistant_turn_id = None
        self._assistant_text = ""
        self._assistant_timestamp = ""

    async def _send_transcript(self, role: str, content: str) -> None:
        """Send a transcript message to the client; empty content is skipped."""
        if content:
            await self._send_transport_message(
                {
                    "type": "transcript",
                    "role": role,
                    "content": content,
                    "timestamp": time_now_iso8601(),
                }
            )

    async def _send_transport_message(self, message: dict[str, Any]) -> None:
        """Push ``message`` downstream as an urgent transport message frame."""
        await self.push_frame(OutputTransportMessageUrgentFrame(message=message))

    @staticmethod
    def _message_text(message: Any) -> str:
        """Extract stripped text from a chat message dict, or ``""``.

        String content is returned stripped; list content yields the ``text``
        of every ``{"type": "text"}`` part joined by newlines.
        """
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            return "\n".join(
                str(part.get("text") or "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            ).strip()
        return ""
|
||||
Reference in New Issue
Block a user