From 278a2fed56040098dd0e21a70fe503915a43500a Mon Sep 17 00:00:00 2001
From: TomTom101
Date: Tue, 28 May 2024 10:51:42 +0200
Subject: [PATCH] wip: First stab at langchain support

Is this a service or processor?
How to deal with conversation history? LC has sophisticated means of this,
but might get in the way of `LLMResponseAggregator`
---
 .../07b-interruptible-langchain.py            | 111 ++++++++++++++++++
 pyproject.toml                                |   1 +
 src/pipecat/services/langchain.py             |  85 ++++++++++++++
 tests/test_langchain.py                       |  57 +++++++++
 4 files changed, 254 insertions(+)
 create mode 100644 examples/foundational/07b-interruptible-langchain.py
 create mode 100644 src/pipecat/services/langchain.py
 create mode 100644 tests/test_langchain.py

diff --git a/examples/foundational/07b-interruptible-langchain.py b/examples/foundational/07b-interruptible-langchain.py
new file mode 100644
index 000000000..40082f994
--- /dev/null
+++ b/examples/foundational/07b-interruptible-langchain.py
@@ -0,0 +1,111 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import os
+import sys
+
+import aiohttp
+from dotenv import load_dotenv
+from loguru import logger
+from runner import configure
+
+from pipecat.frames.frames import LLMMessagesFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineTask
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantResponseAggregator, LLMUserResponseAggregator)
+from pipecat.services.elevenlabs import ElevenLabsTTSService
+from pipecat.services.langchain import LangchainProcessor
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+from pipecat.vad.silero import SileroVADAnalyzer
+
+load_dotenv(override=True)
+
+try:
+    from langchain.prompts import ChatPromptTemplate
+    from langchain_openai import ChatOpenAI
+except ModuleNotFoundError as e:
+    logger.exception(
+        "You need to `pip install langchain_openai` for this example. Also, be sure to set `OPENAI_API_KEY` in the environment variable."
+    )
+    raise Exception(f"Missing module: {e}")
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+async def main(room_url: str, token):
+    async with aiohttp.ClientSession() as session:
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_out_enabled=True,
+                transcription_enabled=True,
+                vad_enabled=True,
+                vad_analyzer=SileroVADAnalyzer(),
+            ),
+        )
+
+        tts = ElevenLabsTTSService(
+            aiohttp_session=session,
+            api_key=os.getenv("ELEVENLABS_API_KEY"),
+            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
+        )
+
+        llm = ChatOpenAI(model="gpt-4o", temperature=0.7)
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system",
+                 "Be nice and helpful. Answer very briefly and without special characters like `#` or `*`. Your response will be synthesized to voice and those characters will create unnatural sounds.",
+                 ),
+                ("human",
+                 "{input}"),
+            ])
+        chain = prompt | llm
+        lc = LangchainProcessor(chain)
+
+        tma_in = LLMUserResponseAggregator()
+        tma_out = LLMAssistantResponseAggregator()
+
+        pipeline = Pipeline(
+            [
+                transport.input(),  # Transport user input
+                tma_in,  # User responses
+                lc,  # Langchain
+                tts,  # TTS
+                transport.output(),  # Transport bot output
+                tma_out,  # Assistant spoken responses
+            ]
+        )
+
+        task = PipelineTask(pipeline, allow_interruptions=True)
+
+        @transport.event_handler("on_first_participant_joined")
+        async def on_first_participant_joined(transport, participant):
+            transport.capture_participant_transcription(participant["id"])
+            # Kick off the conversation.
+            # the `LLMMessagesFrame` will be picked up by the LangchainProcessor using
+            # only the content of the last message to inject it in the prompt defined
+            # above. So no role is required here.
+            messages = [(
+                {
+                    "content": "Please briefly introduce yourself to the user."
+                }
+            )]
+            await task.queue_frames([LLMMessagesFrame(messages)])
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    (url, token) = configure()
+    asyncio.run(main(url, token))
diff --git a/pyproject.toml b/pyproject.toml
index 23245cfdc..75b8c0f72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
 fal = [ "fal-client~=0.4.0" ]
 google = [ "google-generativeai~=0.5.3" ]
 fireworks = [ "openai~=1.26.0" ]
+langchain = [ "langchain~=0.2.1" ]
 local = [ "pyaudio~=0.2.0" ]
 moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
 openai = [ "openai~=1.26.0" ]
diff --git a/src/pipecat/services/langchain.py b/src/pipecat/services/langchain.py
new file mode 100644
index 000000000..6675005eb
--- /dev/null
+++ b/src/pipecat/services/langchain.py
@@ -0,0 +1,85 @@
+import sys
+from typing import Union
+
+from langchain_core.messages import BaseMessage
+from langchain_core.runnables import Runnable
+from loguru import logger
+
+from pipecat.frames.frames import (Frame, LLMFullResponseEndFrame,
+                                   LLMFullResponseStartFrame, LLMMessagesFrame,
+                                   LLMResponseEndFrame, LLMResponseStartFrame,
+                                   TextFrame)
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+
+
+class LangchainProcessor(FrameProcessor):
+    """Answer `LLMMessagesFrame`s by running a Langchain chain.
+
+    The content of the last message is handed to the chain under
+    `transcript_key`. The chain output is pushed as `TextFrame`s wrapped in
+    LLM response start/end frames. Any other frame is forwarded untouched.
+    """
+
+    def __init__(self, chain: Runnable, transcript_key: str = "input"):
+        super().__init__()
+        self._chain = chain
+        self._transcript_key = transcript_key
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Run the chain for `LLMMessagesFrame`s and forward every other frame."""
+        if isinstance(frame, LLMMessagesFrame):
+            # Messages are accumulated by the `LLMUserResponseAggregator` in a list of messages.
+            # The last one by the human is the one we want to send to the LLM.
+            logger.debug(f"Got transcription frame {frame}")
+            if not frame.messages:
+                logger.warning("Ignoring LLMMessagesFrame without messages")
+                return
+            text: str = frame.messages[-1]["content"]
+
+            await self._ainvoke(text.strip())
+        else:
+            # Forward in the direction the frame was travelling.
+            await self.push_frame(frame, direction)
+
+    async def _invoke(self, text: str):
+        """Run the chain once and push its whole answer as a single `TextFrame`."""
+        response = await self._chain.ainvoke({self._transcript_key: text})
+        answer = self.__get_token_value(response)
+        await self.push_frame(LLMFullResponseStartFrame())
+        if answer is not None:
+            await self.push_frame(TextFrame(answer))
+        await self.push_frame(LLMFullResponseEndFrame())
+
+    @staticmethod
+    def __get_token_value(text: Union[str, BaseMessage]) -> str | None:
+        """Return the plain text of a chain output, or `None` if it has none."""
+        match text:
+            case str():
+                return text
+            case BaseMessage():
+                # `content` can also be a list of content blocks; only a plain
+                # string can be wrapped in a `TextFrame`.
+                return text.content if isinstance(text.content, str) else None
+            case _:
+                return None
+
+    async def _ainvoke(self, text: str):
+        """Stream the chain output, pushing one `TextFrame` per text token."""
+        logger.debug(f"Invoking chain with {text}")
+        await self.push_frame(LLMFullResponseStartFrame())
+        try:
+            async for token in self._chain.astream({self._transcript_key: text}):
+                token_text = self.__get_token_value(token)
+                if token_text is None:
+                    # Nothing that can be spoken in this chunk.
+                    continue
+                await self.push_frame(LLMResponseStartFrame())
+                await self.push_frame(TextFrame(token_text))
+                await self.push_frame(LLMResponseEndFrame())
+        except GeneratorExit:
+            logger.warning("Generator was closed prematurely")
+            raise  # Re-raise to ensure proper generator closure
+        except Exception as e:
+            logger.error(f"An unknown error occurred: {e}")
+            raise
+        await self.push_frame(LLMFullResponseEndFrame())
diff --git a/tests/test_langchain.py b/tests/test_langchain.py
new file mode 100644
index 000000000..e204c56a2
--- /dev/null
+++ b/tests/test_langchain.py
@@ -0,0 +1,57 @@
+import pytest
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.language_models import FakeStreamingListLLM
+
+from pipecat.frames.frames import (StopTaskFrame, TranscriptionFrame,
+                                   UserStartedSpeakingFrame,
+                                   UserStoppedSpeakingFrame)
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineTask
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantResponseAggregator, LLMUserResponseAggregator)
+from pipecat.processors.logger import FrameLogger
+from pipecat.services.langchain import LangchainProcessor
+
+
+@pytest.fixture
+def fake_llm():
+    responses = ["Hello dear human"]
+    return FakeStreamingListLLM(responses=responses)
+
+
+@pytest.mark.asyncio
+async def test_langchain(fake_llm: FakeStreamingListLLM):
+    fl_in = FrameLogger("Inner")
+    fl_out = FrameLogger("Outer")
+
+    messages = [("system", "Say hello to {name}"), ("human", "{input}")]
+    prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
+    chain = prompt | fake_llm
+    proc = LangchainProcessor(chain=chain)
+
+    tma_in = LLMUserResponseAggregator(messages)
+    tma_out = LLMAssistantResponseAggregator(messages)
+
+    pipeline = Pipeline(
+        [
+            fl_in,
+            tma_in,
+            proc,
+            tma_out,
+            fl_out,
+        ]
+    )
+
+    task = PipelineTask(pipeline)
+    await task.queue_frames(
+        [
+            UserStartedSpeakingFrame(),
+            TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
+            UserStoppedSpeakingFrame(),
+            StopTaskFrame(),
+        ]
+    )
+
+    runner = PipelineRunner()
+    await runner.run(task)