tests: fix langchanin tests

This commit is contained in:
Aleix Conchillo Flaqué
2024-09-30 11:37:26 -07:00
parent a90ebdfe7c
commit d080a31a5c

View File

@@ -7,9 +7,9 @@
import unittest
from pipecat.frames.frames import (
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StopTaskFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -32,6 +32,7 @@ from langchain_core.language_models import FakeStreamingListLLM
class TestLangchain(unittest.IsolatedAsyncioTestCase):
class MockProcessor(FrameProcessor):
def __init__(self, name):
super().__init__()
self.name = name
self.token: list[str] = []
# Start collecting tokens when we see the start frame
@@ -55,13 +56,13 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.expected_response = "Hello dear human"
self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])
self.mock_proc = self.MockProcessor("token_collector")
async def test_langchain(self):
messages = [("system", "Say hello to {name}"), ("human", "{input}")]
prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
chain = prompt | self.fake_llm
proc = LangchainProcessor(chain=chain)
self.mock_proc = self.MockProcessor("token_collector")
tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)
@@ -81,7 +82,7 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
UserStoppedSpeakingFrame(),
StopTaskFrame(),
EndFrame(),
]
)