wip: First stab at langchain support

Is this a service or processor?
How to deal with conversation history? LC has sophisticated means of this, but might get in the way of `LLMResponseAggregator`
This commit is contained in:
TomTom101
2024-05-28 10:51:42 +02:00
parent c444004eec
commit 278a2fed56
4 changed files with 231 additions and 0 deletions

View File

@@ -0,0 +1,111 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.services.elevenlabs import ElevenLabsTTSService
from pipecat.services.langchain import LangchainProcessor
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
load_dotenv(override=True)
try:
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
except ModuleNotFoundError as e:
logger.exception(
"You need to `pip install langchain_openai` for this example. Also, be sure to set `OPENAI_API_KEY` in the environment variable."
)
raise Exception(f"Missing module: {e}")
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main(room_url: str, token):
async with aiohttp.ClientSession() as session:
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = ElevenLabsTTSService(
aiohttp_session=session,
api_key=os.getenv("ELEVENLABS_API_KEY"),
voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
)
llm = ChatOpenAI(model="gpt-4o", temperature=0.7)
prompt = ChatPromptTemplate.from_messages(
[
("system",
"Be nice and helpful. Answer very briefly and without special characters like `#` or `*`. Your response will be synthesized to voice and those characters will create unnatural sounds.",
),
("human",
"{input}"),
])
chain = prompt | llm
lc = LangchainProcessor(chain)
tma_in = LLMUserResponseAggregator()
tma_out = LLMAssistantResponseAggregator()
pipeline = Pipeline(
[
transport.input(), # Transport user input
tma_in, # User responses
lc, # Langchain
tts, # TTS
transport.output(), # Transport bot output
tma_out, # Assistant spoken responses
]
)
task = PipelineTask(pipeline, allow_interruptions=True)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
# the `LLMMessagesFrame` will be picked up by the LangchainProcessor using
# only the content of the last message to inject it in the prompt defined
# above. So no role is required here.
messages = [(
{
"content": "Please briefly introduce yourself to the user."
}
)]
await task.queue_frames([LLMMessagesFrame(messages)])
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
(url, token) = configure()
asyncio.run(main(url, token))

View File

@@ -40,6 +40,7 @@ examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
fal = [ "fal-client~=0.4.0" ]
google = [ "google-generativeai~=0.5.3" ]
fireworks = [ "openai~=1.26.0" ]
langchain = [ "langchain~=0.2.1" ]
local = [ "pyaudio~=0.2.0" ]
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
openai = [ "openai~=1.26.0" ]

View File

@@ -0,0 +1,62 @@
import sys
from typing import Union
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import Runnable
from loguru import logger
from pipecat.frames.frames import (Frame, LLMFullResponseEndFrame,
LLMFullResponseStartFrame, LLMMessagesFrame,
LLMResponseEndFrame, LLMResponseStartFrame,
TextFrame)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class LangchainProcessor(FrameProcessor):
def __init__(self, chain: Runnable, transcript_key: str = "input"):
super().__init__()
self._chain = chain
self._transcript_key = transcript_key
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, LLMMessagesFrame):
# Messages are accumulated by the `LLMUserResponseAggregator` in a list of messages.
# The last one by the human is the one we want to send to the LLM.
logger.debug(f"Got transcription frame {frame}")
text: str = frame.messages[-1]["content"]
await self._ainvoke(text.strip())
else:
await self.push_frame(frame)
async def _invoke(self, text: str):
response = await self._chain.ainvoke({self._transcript_key: text})
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(TextFrame(response))
await self.push_frame(LLMFullResponseEndFrame())
@staticmethod
def __get_token_value(text: Union[str, AIMessageChunk]) -> str | None:
match text:
case str():
return text
case AIMessageChunk():
return text.content
case _:
return None
async def _ainvoke(self, text: str):
logger.debug(f"Invoking chain with {text}")
await self.push_frame(LLMFullResponseStartFrame())
try:
async for token in self._chain.astream({self._transcript_key: text}):
await self.push_frame(LLMResponseStartFrame())
await self.push_frame(TextFrame(self.__get_token_value(token)))
await self.push_frame(LLMResponseEndFrame())
except GeneratorExit:
logger.warning("Generator was closed prematurely")
raise # Re-raise to ensure proper generator closure
except Exception as e:
logger.error(f"An unknown error occurred: {e}")
raise
await self.push_frame(LLMFullResponseEndFrame())

57
tests/test_langchain.py Normal file
View File

@@ -0,0 +1,57 @@
import pytest
from langchain.prompts import ChatPromptTemplate
from langchain_core.language_models import FakeStreamingListLLM
from pipecat.frames.frames import (StopTaskFrame, TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.processors.logger import FrameLogger
from pipecat.services.langchain import LangchainProcessor
@pytest.fixture
def fake_llm():
responses = ["Hello dear human"]
return FakeStreamingListLLM(responses=responses)
@pytest.mark.asyncio
async def test_langchain(fake_llm: FakeStreamingListLLM):
fl_in = FrameLogger("Inner")
fl_out = FrameLogger("Outer")
messages = [("system", "Say hello to {name}"), ("human", "{input}")]
prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
chain = prompt | fake_llm
proc = LangchainProcessor(chain=chain)
tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)
pipeline = Pipeline(
[
fl_in,
tma_in,
proc,
tma_out,
fl_out,
]
)
task = PipelineTask(pipeline)
await task.queue_frames(
[
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
UserStoppedSpeakingFrame(),
StopTaskFrame(),
]
)
runner = PipelineRunner()
await runner.run(task)