diff --git a/examples/foundational/07b-interruptible-langchain.py b/examples/foundational/07b-interruptible-langchain.py new file mode 100644 index 000000000..5e32964f9 --- /dev/null +++ b/examples/foundational/07b-interruptible-langchain.py @@ -0,0 +1,130 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.frames.frames import LLMMessagesFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineTask +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantResponseAggregator, LLMUserResponseAggregator) +from pipecat.services.elevenlabs import ElevenLabsTTSService +from pipecat.services.langchain import LangchainProcessor +from pipecat.transports.services.daily import DailyParams, DailyTransport +from pipecat.vad.silero import SileroVADAnalyzer + +load_dotenv(override=True) + +try: + from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder + from langchain_community.chat_message_histories import ChatMessageHistory + from langchain_core.chat_history import BaseChatMessageHistory + from langchain_core.runnables.history import RunnableWithMessageHistory + from langchain_openai import ChatOpenAI + +except ModuleNotFoundError as e: + logger.exception( + "In order to run this example you need to `pip install pipecat-ai[langchain] langchain-community langchain-openai. Also, be sure to set `OPENAI_API_KEY` in the environment variable." + ) + raise Exception(f"Missing module: {e}") + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + +message_store = {} + + +def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in message_store: + message_store[session_id] = ChatMessageHistory() + return message_store[session_id] + + +async def main(room_url: str, token): + async with aiohttp.ClientSession() as session: + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = ElevenLabsTTSService( + aiohttp_session=session, + api_key=os.getenv("ELEVENLABS_API_KEY"), + voice_id=os.getenv("ELEVENLABS_VOICE_ID"), + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", + "Be nice and helpful. Answer very briefly and without special characters like `#` or `*`. " + "Your response will be synthesized to voice and those characters will create unnatural sounds.", + ), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ]) + chain = prompt | ChatOpenAI(model="gpt-4o", temperature=0.7) + history_chain = RunnableWithMessageHistory( + chain, + get_session_history, + history_messages_key="chat_history", + input_messages_key="input") + lc = LangchainProcessor(history_chain) + + tma_in = LLMUserResponseAggregator() + tma_out = LLMAssistantResponseAggregator() + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + tma_in, # User responses + lc, # Langchain + tts, # TTS + transport.output(), # Transport bot output + tma_out, # Assistant spoken responses + ] + ) + + task = PipelineTask(pipeline, allow_interruptions=True) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + transport.capture_participant_transcription(participant["id"]) + lc.set_participant_id(participant["id"]) + # Kick off the conversation. + # the `LLMMessagesFrame` will be picked up by the LangchainProcessor using + # only the content of the last message to inject it in the prompt defined + # above. So no role is required here. + messages = [( + { + "content": "Please briefly introduce yourself to the user." + } + )] + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + (url, token) = configure() + asyncio.run(main(url, token)) diff --git a/pyproject.toml b/pyproject.toml index f52db355a..90363e34b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ] fal = [ "fal-client~=0.4.0" ] google = [ "google-generativeai~=0.5.3" ] fireworks = [ "openai~=1.26.0" ] +langchain = [ "langchain~=0.2.1" ] local = [ "pyaudio~=0.2.0" ] moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ] openai = [ "openai~=1.26.0" ] diff --git a/src/pipecat/services/langchain.py b/src/pipecat/services/langchain.py new file mode 100644 index 000000000..83522a16c --- /dev/null +++ b/src/pipecat/services/langchain.py @@ -0,0 +1,79 @@ +import sys +from typing import Union + +from loguru import logger + +from pipecat.frames.frames import (Frame, LLMFullResponseEndFrame, + LLMFullResponseStartFrame, LLMMessagesFrame, + LLMResponseEndFrame, LLMResponseStartFrame, + TextFrame) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +try: + from langchain_core.messages import AIMessageChunk + from langchain_core.runnables import Runnable +except ModuleNotFoundError as e: + logger.exception( + "In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. " + ) + raise Exception(f"Missing module: {e}") + + +class LangchainProcessor(FrameProcessor): + def __init__(self, chain: Runnable, transcript_key: str = "input"): + super().__init__() + self._chain = chain + self._transcript_key = transcript_key + self._participant_id: str | None = None + + def set_participant_id(self, participant_id: str): + self._participant_id = participant_id + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, LLMMessagesFrame): + # Messages are accumulated by the `LLMUserResponseAggregator` in a list of messages. + # The last one by the human is the one we want to send to the LLM. + logger.debug(f"Got transcription frame {frame}") + text: str = frame.messages[-1]["content"] + + await self._ainvoke(text.strip()) + else: + await self.push_frame(frame) + + async def _invoke(self, text: str): + response = await self._chain.ainvoke( + {self._transcript_key: text}, + config={"configurable": {"session_id": self._participant_id}}, + ) + await self.push_frame(LLMFullResponseStartFrame()) + await self.push_frame(TextFrame(response)) + await self.push_frame(LLMFullResponseEndFrame()) + + @staticmethod + def __get_token_value(text: Union[str, AIMessageChunk]) -> str | None: + match text: + case str(): + return text + case AIMessageChunk(): + return text.content + case _: + return None + + async def _ainvoke(self, text: str): + logger.debug(f"Invoking chain with {text}") + await self.push_frame(LLMFullResponseStartFrame()) + try: + async for token in self._chain.astream( + {self._transcript_key: text}, + config={"configurable": {"session_id": self._participant_id}}, + ): + await self.push_frame(LLMResponseStartFrame()) + await self.push_frame(TextFrame(self.__get_token_value(token))) + await self.push_frame(LLMResponseEndFrame()) + except GeneratorExit: + logger.warning("Generator was closed prematurely") + raise # Re-raise to ensure proper generator closure + except Exception as e: + logger.error(f"An unknown error occurred: {e}") + raise + await self.push_frame(LLMFullResponseEndFrame()) diff --git a/tests/test_langchain.py b/tests/test_langchain.py new file mode 100644 index 000000000..0f3ccdc86 --- /dev/null +++ b/tests/test_langchain.py @@ -0,0 +1,86 @@ +import unittest + +from langchain.prompts import ChatPromptTemplate +from langchain_core.language_models import FakeStreamingListLLM + +from pipecat.frames.frames import (LLMFullResponseEndFrame, + LLMFullResponseStartFrame, StopTaskFrame, + TextFrame, TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineTask +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantResponseAggregator, LLMUserResponseAggregator) +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.services.langchain import LangchainProcessor + + +class TestLangchain(unittest.IsolatedAsyncioTestCase): + + class MockProcessor(FrameProcessor): + def __init__(self, name): + self.name = name + self.token: list[str] = [] + # Start collecting tokens when we see the start frame + self.start_collecting = False + + def __str__(self): + return self.name + + async def process_frame(self, frame, direction): + if isinstance(frame, LLMFullResponseStartFrame): + self.start_collecting = True + elif isinstance(frame, TextFrame) and self.start_collecting: + self.token.append(frame.text) + elif isinstance(frame, LLMFullResponseEndFrame): + self.start_collecting = False + + await self.push_frame(frame, direction) + + def setUp(self): + self.expected_response = "Hello dear human" + self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response]) + self.mock_proc = self.MockProcessor("token_collector") + + async def test_langchain(self): + + messages = [("system", "Say hello to {name}"), ("human", "{input}")] + prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas") + chain = prompt | self.fake_llm + proc = LangchainProcessor(chain=chain) + + tma_in = LLMUserResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages) + + pipeline = Pipeline( + [ + tma_in, + proc, + self.mock_proc, + tma_out, + ] + ) + + task = PipelineTask(pipeline) + await task.queue_frames( + [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"), + UserStoppedSpeakingFrame(), + StopTaskFrame(), + ] + ) + + runner = PipelineRunner() + await runner.run(task) + self.assertEqual("".join(self.mock_proc.token), self.expected_response) + # TODO: Address this issue + # This next one would fail with: + # AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human' + # self.assertEqual(tma_out.messages[-1]["content"], self.expected_response) + + +if __name__ == "__main__": + unittest.main()