Merge pull request #190 from TomTom101/TomTom101/langchain
Langchain service
This commit is contained in:
130
examples/foundational/07b-interruptible-langchain.py
Normal file
130
examples/foundational/07b-interruptible-langchain.py
Normal file
@@ -0,0 +1,130 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.langchain import LangchainProcessor
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
try:
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.exception(
|
||||
"In order to run this example you need to `pip install pipecat-ai[langchain] langchain-community langchain-openai. Also, be sure to set `OPENAI_API_KEY` in the environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
message_store = {}
|
||||
|
||||
|
||||
def get_session_history(session_id: str) -> BaseChatMessageHistory:
|
||||
if session_id not in message_store:
|
||||
message_store[session_id] = ChatMessageHistory()
|
||||
return message_store[session_id]
|
||||
|
||||
|
||||
async def main(room_url: str, token):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = ElevenLabsTTSService(
|
||||
aiohttp_session=session,
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY"),
|
||||
voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system",
|
||||
"Be nice and helpful. Answer very briefly and without special characters like `#` or `*`. "
|
||||
"Your response will be synthesized to voice and those characters will create unnatural sounds.",
|
||||
),
|
||||
MessagesPlaceholder("chat_history"),
|
||||
("human", "{input}"),
|
||||
])
|
||||
chain = prompt | ChatOpenAI(model="gpt-4o", temperature=0.7)
|
||||
history_chain = RunnableWithMessageHistory(
|
||||
chain,
|
||||
get_session_history,
|
||||
history_messages_key="chat_history",
|
||||
input_messages_key="input")
|
||||
lc = LangchainProcessor(history_chain)
|
||||
|
||||
tma_in = LLMUserResponseAggregator()
|
||||
tma_out = LLMAssistantResponseAggregator()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
lc, # Langchain
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, allow_interruptions=True)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
lc.set_participant_id(participant["id"])
|
||||
# Kick off the conversation.
|
||||
# the `LLMMessagesFrame` will be picked up by the LangchainProcessor using
|
||||
# only the content of the last message to inject it in the prompt defined
|
||||
# above. So no role is required here.
|
||||
messages = [(
|
||||
{
|
||||
"content": "Please briefly introduce yourself to the user."
|
||||
}
|
||||
)]
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
(url, token) = configure()
|
||||
asyncio.run(main(url, token))
|
||||
@@ -41,6 +41,7 @@ examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
fal = [ "fal-client~=0.4.0" ]
|
||||
google = [ "google-generativeai~=0.5.3" ]
|
||||
fireworks = [ "openai~=1.26.0" ]
|
||||
langchain = [ "langchain~=0.2.1" ]
|
||||
local = [ "pyaudio~=0.2.0" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
|
||||
openai = [ "openai~=1.26.0" ]
|
||||
|
||||
79
src/pipecat/services/langchain.py
Normal file
79
src/pipecat/services/langchain.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (Frame, LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame, LLMMessagesFrame,
|
||||
LLMResponseEndFrame, LLMResponseStartFrame,
|
||||
TextFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
try:
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables import Runnable
|
||||
except ModuleNotFoundError as e:
|
||||
logger.exception(
|
||||
"In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. "
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class LangchainProcessor(FrameProcessor):
|
||||
def __init__(self, chain: Runnable, transcript_key: str = "input"):
|
||||
super().__init__()
|
||||
self._chain = chain
|
||||
self._transcript_key = transcript_key
|
||||
self._participant_id: str | None = None
|
||||
|
||||
def set_participant_id(self, participant_id: str):
|
||||
self._participant_id = participant_id
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, LLMMessagesFrame):
|
||||
# Messages are accumulated by the `LLMUserResponseAggregator` in a list of messages.
|
||||
# The last one by the human is the one we want to send to the LLM.
|
||||
logger.debug(f"Got transcription frame {frame}")
|
||||
text: str = frame.messages[-1]["content"]
|
||||
|
||||
await self._ainvoke(text.strip())
|
||||
else:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _invoke(self, text: str):
|
||||
response = await self._chain.ainvoke(
|
||||
{self._transcript_key: text},
|
||||
config={"configurable": {"session_id": self._participant_id}},
|
||||
)
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(TextFrame(response))
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@staticmethod
|
||||
def __get_token_value(text: Union[str, AIMessageChunk]) -> str | None:
|
||||
match text:
|
||||
case str():
|
||||
return text
|
||||
case AIMessageChunk():
|
||||
return text.content
|
||||
case _:
|
||||
return None
|
||||
|
||||
async def _ainvoke(self, text: str):
|
||||
logger.debug(f"Invoking chain with {text}")
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
try:
|
||||
async for token in self._chain.astream(
|
||||
{self._transcript_key: text},
|
||||
config={"configurable": {"session_id": self._participant_id}},
|
||||
):
|
||||
await self.push_frame(LLMResponseStartFrame())
|
||||
await self.push_frame(TextFrame(self.__get_token_value(token)))
|
||||
await self.push_frame(LLMResponseEndFrame())
|
||||
except GeneratorExit:
|
||||
logger.warning("Generator was closed prematurely")
|
||||
raise # Re-raise to ensure proper generator closure
|
||||
except Exception as e:
|
||||
logger.error(f"An unknown error occurred: {e}")
|
||||
raise
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
86
tests/test_langchain.py
Normal file
86
tests/test_langchain.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import unittest
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain_core.language_models import FakeStreamingListLLM
|
||||
|
||||
from pipecat.frames.frames import (LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame, StopTaskFrame,
|
||||
TextFrame, TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.services.langchain import LangchainProcessor
|
||||
|
||||
|
||||
class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
class MockProcessor(FrameProcessor):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.token: list[str] = []
|
||||
# Start collecting tokens when we see the start frame
|
||||
self.start_collecting = False
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
self.start_collecting = True
|
||||
elif isinstance(frame, TextFrame) and self.start_collecting:
|
||||
self.token.append(frame.text)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
self.start_collecting = False
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def setUp(self):
|
||||
self.expected_response = "Hello dear human"
|
||||
self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
async def test_langchain(self):
|
||||
|
||||
messages = [("system", "Say hello to {name}"), ("human", "{input}")]
|
||||
prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
|
||||
chain = prompt | self.fake_llm
|
||||
proc = LangchainProcessor(chain=chain)
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
tma_in,
|
||||
proc,
|
||||
self.mock_proc,
|
||||
tma_out,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
await task.queue_frames(
|
||||
[
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
|
||||
UserStoppedSpeakingFrame(),
|
||||
StopTaskFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
runner = PipelineRunner()
|
||||
await runner.run(task)
|
||||
self.assertEqual("".join(self.mock_proc.token), self.expected_response)
|
||||
# TODO: Address this issue
|
||||
# This next one would fail with:
|
||||
# AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human'
|
||||
# self.assertEqual(tma_out.messages[-1]["content"], self.expected_response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user