Files
pipecat/tests/test_langchain.py
Mark Backman 8c9e189394 Fix langchain imports for langchain 1.x compatibility
ChatPromptTemplate moved from langchain.prompts to langchain_core.prompts
in langchain 1.x.
2026-03-29 10:27:48 -04:00

114 lines
3.8 KiB
Python

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD-2-Clause
#
import unittest
from langchain_core.language_models import FakeStreamingListLLM
from langchain_core.prompts import ChatPromptTemplate
from pipecat.frames.frames import (
InterruptionFrame,
LLMContextAssistantTimestampFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.processors.frameworks.langchain import LangchainProcessor
from pipecat.tests.utils import SleepFrame, run_test
from pipecat.turns.user_stop import SpeechTimeoutUserTurnStopStrategy
from pipecat.turns.user_turn_strategies import UserTurnStrategies
class TestLangchain(unittest.IsolatedAsyncioTestCase):
    """Exercise ``LangchainProcessor`` in a pipeline backed by a fake streaming LLM."""

    class MockProcessor(FrameProcessor):
        """Pass-through processor that records text seen inside an LLM response.

        Text is recorded only between an ``LLMFullResponseStartFrame`` and the
        following ``LLMFullResponseEndFrame``.
        """

        def __init__(self, name):
            super().__init__(name=name)
            # Collection stays off until an LLMFullResponseStartFrame is seen.
            self.start_collecting = False
            # Text of each TextFrame observed while collection is switched on.
            self.token: list[str] = []

        def __str__(self):
            return self.name

        async def process_frame(self, frame, direction):
            """Update the collection state for ``frame`` and then forward it."""
            await super().process_frame(frame, direction)

            if isinstance(frame, LLMFullResponseStartFrame):
                self.start_collecting = True
            elif self.start_collecting and isinstance(frame, TextFrame):
                self.token.append(frame.text)
            elif isinstance(frame, LLMFullResponseEndFrame):
                self.start_collecting = False

            # Forward the frame in the direction it arrived from.
            await self.push_frame(frame, direction)

    def setUp(self):
        """Create a fake LLM that is given one canned response."""
        canned_reply = "Hello dear human"
        self.expected_response = canned_reply
        self.fake_llm = FakeStreamingListLLM(responses=[canned_reply])

    async def test_langchain(self):
        """Send one user turn through the pipeline and check the LLM reply.

        The reply must match ``expected_response`` both as the text collected
        by the mock processor and as the content of the last message held by
        the assistant aggregator.
        """
        prompt_template = ChatPromptTemplate.from_messages(
            [("system", "Say hello to {name}"), ("human", "{input}")]
        )
        chain = prompt_template.partial(name="Thomas") | self.fake_llm
        langchain_proc = LangchainProcessor(chain=chain)
        self.mock_proc = self.MockProcessor("token_collector")

        llm_context = LLMContext()
        turn_strategies = UserTurnStrategies(stop=[SpeechTimeoutUserTurnStopStrategy()])
        context_aggregator = LLMContextAggregatorPair(
            llm_context,
            user_params=LLMUserAggregatorParams(user_turn_strategies=turn_strategies),
        )

        user_aggregator = context_aggregator.user()
        assistant_aggregator = context_aggregator.assistant()
        pipeline = Pipeline(
            [user_aggregator, langchain_proc, self.mock_proc, assistant_aggregator]
        )

        # Input frames for a single user turn containing one transcription.
        user_turn = [
            VADUserStartedSpeakingFrame(),
            TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
            SleepFrame(),
            VADUserStoppedSpeakingFrame(),
            SleepFrame(sleep=1.0),
        ]

        # Frame types passed to run_test as the expected downstream frames.
        downstream_types = [
            VADUserStartedSpeakingFrame,
            UserStartedSpeakingFrame,
            InterruptionFrame,
            VADUserStoppedSpeakingFrame,
            UserStoppedSpeakingFrame,
            LLMContextFrame,
            LLMContextAssistantTimestampFrame,
        ]

        await run_test(
            pipeline,
            frames_to_send=user_turn,
            expected_down_frames=downstream_types,
        )

        streamed_text = "".join(self.mock_proc.token)
        self.assertEqual(streamed_text, self.expected_response)

        last_message = context_aggregator.assistant().messages[-1]
        self.assertEqual(last_message["content"], self.expected_response)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()