99 lines
2.9 KiB
Python
99 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.frames.frames import (
|
|
ControlFrame,
|
|
Frame,
|
|
InputTransportMessageFrame,
|
|
OutputTransportMessageUrgentFrame,
|
|
UninterruptibleFrame,
|
|
)
|
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
|
|
|
|
|
@dataclass
class FastGPTStateFlushRequestFrame(ControlFrame, UninterruptibleFrame):
    """A single queued FastGPT ``set_info`` write request.

    Each instance describes exactly one pending update. It is deliberately not
    a cache of session state: the FastGPT service still reads the latest state
    before applying the write.
    """

    # Correlation id for this request; SetInfoProcessor uses the same id in the
    # ``session.set_info.ack`` messages it emits.
    request_id: str
    # State key to write (SetInfoProcessor strips surrounding whitespace).
    key: str
    # Value to store under ``key``; forwarded as received, without validation.
    value: Any
|
|
|
|
|
|
class SetInfoProcessor(FrameProcessor):
    """Converts product set_info messages into queued FastGPT state writes.

    Frames that are not an ``InputTransportMessageFrame`` carrying a dict with
    ``"type": "session.set_info"`` pass through unchanged, in their original
    direction. A matching message is consumed and then either:

    * rejected with a ``session.set_info.ack`` (``ok=False``, non-retryable)
      when the processor is disabled or ``key`` is not a non-empty string, or
    * converted into a ``FastGPTStateFlushRequestFrame`` pushed downstream.

    This processor never emits a success ack itself. Presumably the consumer
    of the flush frame acknowledges success (NOTE(review): confirm).
    """

    def __init__(self, *, enabled: bool = True):
        """Create the processor.

        Args:
            enabled: When False, every ``session.set_info`` request is rejected
                with a non-retryable error ack instead of being queued.
        """
        super().__init__()
        self._enabled = enabled

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Intercept ``session.set_info`` messages and forward all other frames.

        Args:
            frame: The incoming frame.
            direction: The direction the frame is travelling; non-matching
                frames are forwarded in this same direction.
        """
        await super().process_frame(frame, direction)

        if not isinstance(frame, InputTransportMessageFrame):
            await self.push_frame(frame, direction)
            return

        message = frame.message
        if not isinstance(message, dict) or message.get("type") != "session.set_info":
            await self.push_frame(frame, direction)
            return

        request_id = self._resolve_request_id(message.get("request_id"))

        if not self._enabled:
            await self._push_ack(
                request_id=request_id,
                ok=False,
                error="session.set_info requires FastGPT LLM backend",
                retryable=False,
            )
            return

        # Normalize once so the log line and the queued frame agree on the key.
        raw_key = message.get("key")
        key = raw_key.strip() if isinstance(raw_key, str) else ""
        if not key:
            await self._push_ack(
                request_id=request_id,
                ok=False,
                error="session.set_info requires non-empty string key",
                retryable=False,
            )
            return

        logger.info(f"Queueing FastGPT set_info request request_id={request_id} key={key!r}")
        await self.push_frame(
            FastGPTStateFlushRequestFrame(
                request_id=request_id,
                key=key,
                value=message.get("value"),
            ),
            FrameDirection.DOWNSTREAM,
        )

    @staticmethod
    def _resolve_request_id(raw: Any) -> str:
        """Return the client-supplied request id as a string, or a fresh hex UUID.

        Only ``None`` and ``""`` count as missing. Falsy but meaningful ids such
        as ``0`` are echoed back, so the client can still correlate the ack.
        """
        if raw is None or raw == "":
            return uuid.uuid4().hex
        return str(raw)

    async def _push_ack(
        self,
        *,
        request_id: str,
        ok: bool,
        error: str | None = None,
        retryable: bool | None = None,
    ) -> None:
        """Push a ``session.set_info.ack`` message as an urgent output frame.

        Args:
            request_id: Id of the request being acknowledged.
            ok: Whether the request was accepted.
            error: Optional error description; omitted from the payload if None.
            retryable: Optional retry hint; omitted from the payload if None.
        """
        payload: dict[str, Any] = {
            "type": "session.set_info.ack",
            "request_id": request_id,
            "ok": ok,
        }
        if error is not None:
            payload["error"] = error
        if retryable is not None:
            payload["retryable"] = retryable
        # Explicit direction, matching the flush-frame push above.
        await self.push_frame(
            OutputTransportMessageUrgentFrame(message=payload),
            FrameDirection.DOWNSTREAM,
        )
|