ChatPromptTemplate moved from langchain.prompts to langchain_core.prompts in langchain 1.x.
114 lines
3.8 KiB
Python
114 lines
3.8 KiB
Python
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
|
|
|
|
import unittest
|
|
|
|
from langchain_core.language_models import FakeStreamingListLLM
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
from pipecat.frames.frames import (
|
|
InterruptionFrame,
|
|
LLMContextAssistantTimestampFrame,
|
|
LLMContextFrame,
|
|
LLMFullResponseEndFrame,
|
|
LLMFullResponseStartFrame,
|
|
TextFrame,
|
|
TranscriptionFrame,
|
|
UserStartedSpeakingFrame,
|
|
UserStoppedSpeakingFrame,
|
|
VADUserStartedSpeakingFrame,
|
|
VADUserStoppedSpeakingFrame,
|
|
)
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
|
from pipecat.processors.aggregators.llm_response_universal import (
|
|
LLMContextAggregatorPair,
|
|
LLMUserAggregatorParams,
|
|
)
|
|
from pipecat.processors.frame_processor import FrameProcessor
|
|
from pipecat.processors.frameworks.langchain import LangchainProcessor
|
|
from pipecat.tests.utils import SleepFrame, run_test
|
|
from pipecat.turns.user_stop import SpeechTimeoutUserTurnStopStrategy
|
|
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
|
|
|
|
|
class TestLangchain(unittest.IsolatedAsyncioTestCase):
    """Run a LangchainProcessor inside a pipeline, backed by a fake streaming LLM."""

    class MockProcessor(FrameProcessor):
        """Pass-through processor that records text seen inside an LLM response.

        Text frames are recorded only between an ``LLMFullResponseStartFrame``
        and the following ``LLMFullResponseEndFrame``. Every frame is pushed on
        unchanged in the direction it arrived.
        """

        def __init__(self, name):
            super().__init__(name=name)
            # Recording stays off until the response start frame is observed.
            self.start_collecting = False
            self.token: list[str] = []

        def __str__(self):
            return self.name

        async def process_frame(self, frame, direction):
            """Update the recording state for ``frame``, then forward it."""
            await super().process_frame(frame, direction)

            if isinstance(frame, LLMFullResponseStartFrame):
                self.start_collecting = True
            elif self.start_collecting and isinstance(frame, TextFrame):
                self.token.append(frame.text)
            elif isinstance(frame, LLMFullResponseEndFrame):
                self.start_collecting = False

            await self.push_frame(frame, direction)

    def setUp(self):
        """Create a fake LLM whose only canned response is the expected text."""
        self.expected_response = "Hello dear human"
        self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])

    async def test_langchain(self):
        """Check the streamed tokens and the last assistant message match the fake reply."""
        template = ChatPromptTemplate.from_messages(
            [("system", "Say hello to {name}"), ("human", "{input}")]
        )
        chain = template.partial(name="Thomas") | self.fake_llm
        langchain_proc = LangchainProcessor(chain=chain)
        self.mock_proc = self.MockProcessor("token_collector")

        llm_context = LLMContext()
        turn_strategies = UserTurnStrategies(stop=[SpeechTimeoutUserTurnStopStrategy()])
        context_aggregator = LLMContextAggregatorPair(
            llm_context,
            user_params=LLMUserAggregatorParams(user_turn_strategies=turn_strategies),
        )

        # The token collector sits between the LangChain processor and the
        # assistant aggregator.
        pipeline = Pipeline(
            [
                context_aggregator.user(),
                langchain_proc,
                self.mock_proc,
                context_aggregator.assistant(),
            ]
        )

        # NOTE(review): the trailing one-second sleep presumably gives the
        # speech-timeout stop strategy time to end the user turn -- confirm.
        user_turn_frames = [
            VADUserStartedSpeakingFrame(),
            TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
            SleepFrame(),
            VADUserStoppedSpeakingFrame(),
            SleepFrame(sleep=1.0),
        ]
        downstream_frame_types = [
            VADUserStartedSpeakingFrame,
            UserStartedSpeakingFrame,
            InterruptionFrame,
            VADUserStoppedSpeakingFrame,
            UserStoppedSpeakingFrame,
            LLMContextFrame,
            LLMContextAssistantTimestampFrame,
        ]

        await run_test(
            pipeline,
            frames_to_send=user_turn_frames,
            expected_down_frames=downstream_frame_types,
        )

        collected_text = "".join(self.mock_proc.token)
        self.assertEqual(collected_text, self.expected_response)

        last_assistant_message = context_aggregator.assistant().messages[-1]
        self.assertEqual(last_assistant_message["content"], self.expected_response)
|
|
|
|
|
|
# Run the tests through unittest's command-line runner when executed directly.
if __name__ == "__main__":
    unittest.main()
|