Fix: text input did not update the LLM context
This commit is contained in:
@@ -14,7 +14,6 @@ from pipecat.pipeline.runner import PipelineRunner
|
|||||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||||
from pipecat.processors.aggregators.llm_response_universal import (
|
from pipecat.processors.aggregators.llm_response_universal import (
|
||||||
AssistantTurnStoppedMessage,
|
|
||||||
LLMContextAggregatorPair,
|
LLMContextAggregatorPair,
|
||||||
LLMUserAggregatorParams,
|
LLMUserAggregatorParams,
|
||||||
UserTurnStoppedMessage,
|
UserTurnStoppedMessage,
|
||||||
@@ -34,7 +33,7 @@ from .config import EngineConfig
|
|||||||
from .product_protocol import ProductWebsocketSerializer
|
from .product_protocol import ProductWebsocketSerializer
|
||||||
from .services import create_llm_service, create_stt_service, create_tts_service
|
from .services import create_llm_service, create_stt_service, create_tts_service
|
||||||
from .text_input import ProductTextInputProcessor
|
from .text_input import ProductTextInputProcessor
|
||||||
from .text_stream import ProductTextStreamProcessor
|
from .text_stream import ProductAssistantTurnStoppedMessage, ProductTextStreamProcessor
|
||||||
from .transcript_stream import ProductTranscriptStreamProcessor
|
from .transcript_stream import ProductTranscriptStreamProcessor
|
||||||
from .turn_start import InterruptionGateUserTurnStartStrategy
|
from .turn_start import InterruptionGateUserTurnStartStrategy
|
||||||
|
|
||||||
@@ -118,13 +117,14 @@ async def run_pipeline_with_serializer(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
user_aggregator, _ = LLMContextAggregatorPair(
|
||||||
context,
|
context,
|
||||||
user_params=LLMUserAggregatorParams(
|
user_params=LLMUserAggregatorParams(
|
||||||
vad_analyzer=SileroVADAnalyzer(params=vad_params),
|
vad_analyzer=SileroVADAnalyzer(params=vad_params),
|
||||||
user_turn_strategies=user_turn_strategies,
|
user_turn_strategies=user_turn_strategies,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
text_stream = ProductTextStreamProcessor(context)
|
||||||
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
[
|
[
|
||||||
@@ -134,10 +134,9 @@ async def run_pipeline_with_serializer(
|
|||||||
ProductTranscriptStreamProcessor(),
|
ProductTranscriptStreamProcessor(),
|
||||||
user_aggregator,
|
user_aggregator,
|
||||||
llm,
|
llm,
|
||||||
ProductTextStreamProcessor(),
|
text_stream,
|
||||||
tts,
|
tts,
|
||||||
transport.output(),
|
transport.output(),
|
||||||
assistant_aggregator,
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -188,11 +187,10 @@ async def run_pipeline_with_serializer(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: assistant turn started/final events are emitted by
|
@text_stream.event_handler("on_assistant_turn_stopped")
|
||||||
# ProductTextStreamProcessor, upstream of TTS, so text streams to the
|
async def on_assistant_turn_stopped(
|
||||||
# client ahead of audio. This logger is kept for server-side visibility.
|
_aggregator, message: ProductAssistantTurnStoppedMessage
|
||||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
|
):
|
||||||
async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage):
|
|
||||||
logger.info(f"Assistant: {message.content}")
|
logger.info(f"Assistant: {message.content}")
|
||||||
|
|
||||||
runner = PipelineRunner(handle_sigint=False)
|
runner = PipelineRunner(handle_sigint=False)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from pipecat.frames.frames import (
|
from pipecat.frames.frames import (
|
||||||
Frame,
|
Frame,
|
||||||
InterruptionFrame,
|
InterruptionFrame,
|
||||||
@@ -9,7 +13,16 @@ from pipecat.frames.frames import (
|
|||||||
OutputTransportMessageUrgentFrame,
|
OutputTransportMessageUrgentFrame,
|
||||||
TTSSpeakFrame,
|
TTSSpeakFrame,
|
||||||
)
|
)
|
||||||
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
|
from pipecat.utils.time import time_now_iso8601
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ProductAssistantTurnStoppedMessage:
    """Payload passed to ``on_assistant_turn_stopped`` event handlers.

    ProductTextStreamProcessor builds one of these when it finishes an
    assistant turn and hands it to the registered handlers.
    """

    # Full assistant text for the turn: the processor's aggregated text
    # pieces joined into a single string.
    content: str
    # The interruption flag the processor was finishing the turn with.
    interrupted: bool
    # Timestamp string recorded when the turn started; the processor falls
    # back to the current time if no start timestamp was recorded.
    # NOTE(review): presumably ISO 8601 (comes from time_now_iso8601) - confirm.
    timestamp: str
|
||||||
|
|
||||||
|
|
||||||
class ProductTextStreamProcessor(FrameProcessor):
|
class ProductTextStreamProcessor(FrameProcessor):
|
||||||
@@ -29,10 +42,13 @@ class ProductTextStreamProcessor(FrameProcessor):
|
|||||||
started/delta/final sequence for its fixed text.
|
started/delta/final sequence for its fixed text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, context: LLMContext | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._context = context
|
||||||
self._aggregation: list[str] = []
|
self._aggregation: list[str] = []
|
||||||
self._turn_active = False
|
self._turn_active = False
|
||||||
|
self._turn_start_timestamp = ""
|
||||||
|
self._register_event_handler("on_assistant_turn_stopped")
|
||||||
|
|
||||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||||
await super().process_frame(frame, direction)
|
await super().process_frame(frame, direction)
|
||||||
@@ -62,6 +78,7 @@ class ProductTextStreamProcessor(FrameProcessor):
|
|||||||
return
|
return
|
||||||
self._turn_active = True
|
self._turn_active = True
|
||||||
self._aggregation = []
|
self._aggregation = []
|
||||||
|
self._turn_start_timestamp = time_now_iso8601()
|
||||||
await self._emit("response.text.started")
|
await self._emit("response.text.started")
|
||||||
|
|
||||||
async def _delta(self, text: str) -> None:
|
async def _delta(self, text: str) -> None:
|
||||||
@@ -78,11 +95,26 @@ class ProductTextStreamProcessor(FrameProcessor):
|
|||||||
full_text = "".join(self._aggregation)
|
full_text = "".join(self._aggregation)
|
||||||
self._turn_active = False
|
self._turn_active = False
|
||||||
self._aggregation = []
|
self._aggregation = []
|
||||||
|
if self._context and full_text:
|
||||||
|
self._context.add_message({"role": "assistant", "content": full_text})
|
||||||
|
logger.info(
|
||||||
|
"Assistant committed to LLM context before TTS: "
|
||||||
|
f"{full_text[:120]}"
|
||||||
|
)
|
||||||
await self._emit(
|
await self._emit(
|
||||||
"response.text.final",
|
"response.text.final",
|
||||||
text=full_text,
|
text=full_text,
|
||||||
interrupted=interrupted,
|
interrupted=interrupted,
|
||||||
)
|
)
|
||||||
|
await self._call_event_handler(
|
||||||
|
"on_assistant_turn_stopped",
|
||||||
|
ProductAssistantTurnStoppedMessage(
|
||||||
|
content=full_text,
|
||||||
|
interrupted=interrupted,
|
||||||
|
timestamp=self._turn_start_timestamp or time_now_iso8601(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._turn_start_timestamp = ""
|
||||||
|
|
||||||
async def _emit(self, event_type: str, **payload: object) -> None:
|
async def _emit(self, event_type: str, **payload: object) -> None:
|
||||||
await self.push_frame(
|
await self.push_frame(
|
||||||
|
|||||||
Reference in New Issue
Block a user