diff --git a/examples/foundational/07l-interruptible-together.py b/examples/foundational/07l-interruptible-together.py
index ca3386718..a99b07a1a 100644
--- a/examples/foundational/07l-interruptible-together.py
+++ b/examples/foundational/07l-interruptible-together.py
@@ -67,7 +67,7 @@ async def main():
messages = [
{
"role": "system",
- "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
+ "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond in plain language. Respond to what the user said in a creative and helpful way.",
},
]
@@ -87,7 +87,12 @@ async def main():
]
)
- task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
+ task = PipelineTask(
+ pipeline,
+ PipelineParams(
+ allow_interruptions=True, enable_metrics=True, enable_usage_metrics=True
+ ),
+ )
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
diff --git a/examples/foundational/14-function-calling.py b/examples/foundational/14-function-calling.py
index 9141029ca..35a02743b 100644
--- a/examples/foundational/14-function-calling.py
+++ b/examples/foundational/14-function-calling.py
@@ -9,11 +9,9 @@ import aiohttp
import os
import sys
-from pipecat.frames.frames import TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
-from pipecat.processors.logger import FrameLogger
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -72,9 +70,6 @@ async def main():
# sent to the same callback with an additional function_name parameter.
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
- fl_in = FrameLogger("Inner")
- fl_out = FrameLogger("Outer")
-
tools = [
ChatCompletionToolParam(
type="function",
@@ -111,11 +106,9 @@ async def main():
pipeline = Pipeline(
[
- # fl_in,
transport.input(),
context_aggregator.user(),
llm,
- # fl_out,
tts,
transport.output(),
context_aggregator.assistant(),
diff --git a/examples/foundational/19a-tools-anthropic.py b/examples/foundational/14a-function-calling-anthropic.py
similarity index 100%
rename from examples/foundational/19a-tools-anthropic.py
rename to examples/foundational/14a-function-calling-anthropic.py
diff --git a/examples/foundational/19b-tools-video-anthropic.py b/examples/foundational/14b-function-calling-anthropic-video.py
similarity index 100%
rename from examples/foundational/19b-tools-video-anthropic.py
rename to examples/foundational/14b-function-calling-anthropic-video.py
diff --git a/examples/foundational/14c-function-calling-together.py b/examples/foundational/14c-function-calling-together.py
new file mode 100644
index 000000000..ebfc4b5df
--- /dev/null
+++ b/examples/foundational/14c-function-calling-together.py
@@ -0,0 +1,136 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import aiohttp
+import os
+import sys
+
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineTask
+from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.openai import OpenAILLMContext
+from pipecat.services.together import TogetherLLMService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+from pipecat.vad.silero import SileroVADAnalyzer
+
+from openai.types.chat import ChatCompletionToolParam
+
+from runner import configure
+
+from loguru import logger
+
+from dotenv import load_dotenv
+
+# Load settings from a .env file; the API keys used in main() are read via os.getenv().
+load_dotenv(override=True)
+
+# Replace the logger handler with id 0 by a stderr sink at DEBUG level.
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+async def start_fetch_weather(function_name, llm, context):
+    """Log that a weather lookup is about to start.
+
+    Registered in main() as the ``start_callback`` of the LLM function handler.
+
+    NOTE: no frame is pushed to the LLM from here (for example
+    ``TextFrame("Let me check on that.")``): the bot can interrupt itself
+    and/or cause overlapping-audio glitches. What the right way to trigger
+    speech from this callback is, with the new queues/async/sync refactors,
+    is still an open question.
+    """
+    start_message = f"Starting fetch_weather_from_api with function_name: {function_name}"
+    logger.debug(start_message)
+
+
+async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
+    """Report a canned weather result through ``result_callback``.
+
+    No request is made: every call answers with the same fixed conditions,
+    regardless of ``args``.
+    """
+    canned_weather = {"conditions": "nice", "temperature": "75"}
+    await result_callback(canned_weather)
+
+
+async def main():
+    """Run the Together function-calling example bot in a Daily room.
+
+    Builds the transport, TTS and LLM services, registers the weather
+    callbacks, assembles the pipeline and runs it with a PipelineRunner.
+    """
+    async with aiohttp.ClientSession() as session:
+        (room_url, token) = await configure(session)
+
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_out_enabled=True,
+                transcription_enabled=True,
+                vad_enabled=True,
+                vad_analyzer=SileroVADAnalyzer(),
+            ),
+        )
+
+        tts = CartesiaTTSService(
+            api_key=os.getenv("CARTESIA_API_KEY"),
+            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
+        )
+
+        llm = TogetherLLMService(
+            api_key=os.getenv("TOGETHER_API_KEY"),
+            model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+        )
+        # Register a function_name of None to get all functions
+        # sent to the same callback with an additional function_name parameter.
+        llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
+
+        # Tool schema advertised to the LLM; calls are routed to the callback registered above.
+        tools = [
+            ChatCompletionToolParam(
+                type="function",
+                function={
+                    "name": "get_current_weather",
+                    "description": "Get the current weather",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "format": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                                "description": "The temperature unit to use. Infer this from the users location.",
+                            },
+                        },
+                        "required": ["location", "format"],
+                    },
+                },
+            )
+        ]
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
+            },
+        ]
+
+        # The context carries both the messages and the tool definitions.
+        context = OpenAILLMContext(messages, tools)
+        context_aggregator = llm.create_context_aggregator(context)
+
+        # Processors in pipeline order: transport input, user aggregator, LLM,
+        # TTS, transport output, and the assistant aggregator last.
+        pipeline = Pipeline(
+            [
+                transport.input(),
+                context_aggregator.user(),
+                llm,
+                tts,
+                transport.output(),
+                context_aggregator.assistant(),
+            ]
+        )
+
+        task = PipelineTask(pipeline)
+
+        @transport.event_handler("on_first_participant_joined")
+        async def on_first_participant_joined(transport, participant):
+            # Capture the transcription of the participant that triggered this event.
+            transport.capture_participant_transcription(participant["id"])
+            # Kick off the conversation.
+            # await tts.say("Hi! Ask me about the weather in San Francisco.")
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+# Run the bot only when this file is executed as a script.
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/foundational/14d-function-calling-video.py b/examples/foundational/14d-function-calling-video.py
new file mode 100644
index 000000000..f42665d5b
--- /dev/null
+++ b/examples/foundational/14d-function-calling-video.py
@@ -0,0 +1,167 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import aiohttp
+import os
+import sys
+
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineTask
+from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+from pipecat.vad.silero import SileroVADAnalyzer
+
+from openai.types.chat import ChatCompletionToolParam
+
+from runner import configure
+
+from loguru import logger
+
+from dotenv import load_dotenv
+
+# Load settings from a .env file; the API keys used in main() are read via os.getenv().
+load_dotenv(override=True)
+
+# Replace the logger handler with id 0 by a stderr sink at DEBUG level.
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+# Id of the participant whose video is captured. Set by on_first_participant_joined
+# in main() and read by get_image().
+video_participant_id = None
+
+
+async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
+    """Answer a get_weather tool call with a fixed forecast for the requested location."""
+    place = arguments["location"]
+    forecast = f"The weather in {place} is currently 72 degrees and sunny."
+    await result_callback(forecast)
+
+
+async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
+    """Handle a get_image tool call by requesting an image frame for the tracked participant.
+
+    The tool's ``question`` argument is forwarded as the text content of the
+    request. ``result_callback`` is not invoked here.
+    """
+    logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}")
+    prompt = arguments["question"]
+    await llm.request_image_frame(user_id=video_participant_id, text_content=prompt)
+
+
+async def main():
+    """Run the weather-and-video function-calling example bot in a Daily room.
+
+    Builds the transport, TTS and LLM services, registers the ``get_weather``
+    and ``get_image`` tool callbacks, assembles the pipeline and runs it with
+    a PipelineRunner.
+    """
+    async with aiohttp.ClientSession() as session:
+        (room_url, token) = await configure(session)
+
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_out_enabled=True,
+                transcription_enabled=True,
+                vad_enabled=True,
+                vad_analyzer=SileroVADAnalyzer(),
+            ),
+        )
+
+        tts = CartesiaTTSService(
+            api_key=os.getenv("CARTESIA_API_KEY"),
+            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
+        )
+
+        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
+        llm.register_function("get_weather", get_weather)
+        llm.register_function("get_image", get_image)
+
+        # Tool schemas advertised to the LLM; the names match the callbacks registered above.
+        tools = [
+            ChatCompletionToolParam(
+                type="function",
+                function={
+                    "name": "get_weather",
+                    "description": "Get the current weather",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "format": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                                "description": "The temperature unit to use. Infer this from the users location.",
+                            },
+                        },
+                        "required": ["location", "format"],
+                    },
+                },
+            ),
+            ChatCompletionToolParam(
+                type="function",
+                function={
+                    "name": "get_image",
+                    "description": "Get an image from the video stream.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "question": {
+                                "type": "string",
+                                # get_image() sends this along with the image request, so it
+                                # must describe what the user wants to know about the image.
+                                "description": "The question that the user is asking about the image.",
+                            },
+                        },
+                        "required": ["question"],
+                    },
+                },
+            ),
+        ]
+
+        system_prompt = """\
+You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
+
+Your response will be turned into speech so use only simple words and punctuation.
+
+You have access to two tools: get_weather and get_image.
+
+You can respond to questions about the weather using the get_weather tool.
+
+You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
+indicate you should use the get_image tool are:
+  - What do you see?
+  - What's in the video?
+  - Can you describe the video?
+  - Tell me about what you see.
+  - Tell me something interesting about what you see.
+  - What's happening in the video?
+"""
+        messages = [
+            {"role": "system", "content": system_prompt},
+        ]
+
+        # The context carries both the messages and the tool definitions.
+        context = OpenAILLMContext(messages, tools)
+        context_aggregator = llm.create_context_aggregator(context)
+
+        # Processors in pipeline order: transport input, user aggregator, LLM,
+        # TTS, transport output, and the assistant aggregator last.
+        pipeline = Pipeline(
+            [
+                transport.input(),
+                context_aggregator.user(),
+                llm,
+                tts,
+                transport.output(),
+                context_aggregator.assistant(),
+            ]
+        )
+
+        task = PipelineTask(pipeline)
+
+        @transport.event_handler("on_first_participant_joined")
+        async def on_first_participant_joined(transport, participant):
+            # Remember who joined so get_image() can request frames from that participant.
+            global video_participant_id
+            video_participant_id = participant["id"]
+            transport.capture_participant_transcription(participant["id"])
+            transport.capture_participant_video(video_participant_id, framerate=0)
+            # Kick off the conversation.
+            await tts.say("Hi! Ask me about the weather in San Francisco.")
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+# Run the bot only when this file is executed as a script.
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/foundational/19c-tools-togetherai.py b/examples/foundational/19c-tools-togetherai.py
deleted file mode 100644
index f8e63ef75..000000000
--- a/examples/foundational/19c-tools-togetherai.py
+++ /dev/null
@@ -1,137 +0,0 @@
-#
-# Copyright (c) 2024, Daily
-#
-# SPDX-License-Identifier: BSD 2-Clause License
-#
-
-import asyncio
-import aiohttp
-import os
-import sys
-import json
-
-from pipecat.frames.frames import LLMMessagesFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.runner import PipelineRunner
-from pipecat.pipeline.task import PipelineParams, PipelineTask
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.cartesia import CartesiaTTSService
-from pipecat.services.together import TogetherLLMService
-from pipecat.transports.services.daily import DailyParams, DailyTransport
-from pipecat.vad.silero import SileroVADAnalyzer
-
-from runner import configure
-
-from loguru import logger
-
-from dotenv import load_dotenv
-
-load_dotenv(override=True)
-
-logger.remove(0)
-logger.add(sys.stderr, level="DEBUG")
-
-
-async def get_current_weather(
- function_name, tool_call_id, arguments, llm, context, result_callback
-):
- logger.debug("IN get_current_weather")
- location = arguments["location"]
- await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
-
-
-async def main():
- async with aiohttp.ClientSession() as session:
- (room_url, token) = await configure(session)
-
- transport = DailyTransport(
- room_url,
- token,
- "Respond bot",
- DailyParams(
- audio_out_enabled=True,
- transcription_enabled=True,
- vad_enabled=True,
- vad_analyzer=SileroVADAnalyzer(),
- ),
- )
-
- tts = CartesiaTTSService(
- api_key=os.getenv("CARTESIA_API_KEY"),
- voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
- )
-
- llm = TogetherLLMService(
- api_key=os.getenv("TOGETHER_API_KEY"),
- model=os.getenv("TOGETHER_MODEL"),
- )
- llm.register_function("get_current_weather", get_current_weather)
-
- weatherTool = {
- "name": "get_current_weather",
- "description": "Get the current weather in a given location",
- "parameters": {
- "type": "object",
- "properties": {
- "location": {
- "type": "string",
- "description": "The city and state, e.g. San Francisco, CA",
- },
- },
- "required": ["location"],
- },
- }
-
- system_prompt = f"""\
-You have access to the following functions:
-
-Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
-{json.dumps(weatherTool)}
-
-If you choose to call a function ONLY reply in the following format with no prefix or suffix:
-
-{{\"example_name\": \"example_value\"}}
-
-Reminder:
-- Function calls MUST follow the specified format, start with
-- Required parameters MUST be specified
-- Only call one function at a time
-- Put the entire function call reply on one line
-- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
-
-"""
-
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": "Wait for the user to say something."},
- ]
-
- context = OpenAILLMContext(messages)
- context_aggregator = llm.create_context_aggregator(context)
-
- pipeline = Pipeline(
- [
- transport.input(), # Transport user input
- context_aggregator.user(), # User speech to text
- llm, # LLM
- tts, # TTS
- transport.output(), # Transport bot output
- context_aggregator.assistant(), # Assistant spoken responses and tool context
- ]
- )
-
- task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))
-
- @transport.event_handler("on_first_participant_joined")
- async def on_first_participant_joined(transport, participant):
- transport.capture_participant_transcription(participant["id"])
- # Kick off the conversation.
- await task.queue_frames([LLMMessagesFrame(messages)])
-
- runner = PipelineRunner()
-
- await runner.run(task)
-
-
-if __name__ == "__main__":
- asyncio.run(main())
diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py
index 4bf3f042c..c86045fab 100644
--- a/src/pipecat/processors/aggregators/openai_llm_context.py
+++ b/src/pipecat/processors/aggregators/openai_llm_context.py
@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
+import base64
+import copy
import io
import json
@@ -60,6 +62,7 @@ class OpenAILLMContext:
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
+ self._user_image_request_context = {}
@staticmethod
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
@@ -114,6 +117,19 @@ class OpenAILLMContext:
def get_messages_json(self) -> str:
return json.dumps(self._messages, cls=CustomEncoder)
+    def get_messages_for_logging(self) -> str:
+        """Return the messages as a JSON string with inline image data elided.
+
+        Works on a deep copy, so the stored messages are left untouched. In
+        every list-valued ``content``, an ``image_url`` item whose URL starts
+        with ``data:image/`` has that URL replaced by the placeholder
+        ``data:image/...`` so the encoded image does not end up in the logs.
+
+        Returns:
+            The JSON-encoded, shortened copy of the messages.
+        """
+        msgs = copy.deepcopy(self.messages)
+        for msg in msgs:
+            content = msg.get("content")
+            if not isinstance(content, list):
+                continue
+            for item in content:
+                # Tolerate malformed items: a logging helper should not raise on them.
+                if not isinstance(item, dict) or item.get("type") != "image_url":
+                    continue
+                if item["image_url"]["url"].startswith("data:image/"):
+                    item["image_url"]["url"] = "data:image/..."
+        # Same encoder as get_messages_json(), so whatever that method can
+        # serialize can also be logged.
+        return json.dumps(msgs, cls=CustomEncoder)
+
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
self._tool_choice = tool_choice
@@ -122,6 +138,21 @@ class OpenAILLMContext:
tools = NOT_GIVEN
self._tools = tools
+    def add_image_frame_message(
+        self, *, format: str, size: tuple[int, int], image: bytes, text: str | None = None
+    ):
+        """Append a user message holding an image, optionally preceded by text.
+
+        The raw image is re-encoded as JPEG and embedded in the message as a
+        base64 ``data:image/jpeg`` URL.
+
+        Args:
+            format: Passed as the first argument to ``Image.frombytes``.
+            size: Passed as the second argument to ``Image.frombytes``.
+            image: Raw image bytes, passed as the third argument to ``Image.frombytes``.
+            text: Optional text to accompany the image.
+        """
+        buffer = io.BytesIO()
+        Image.frombytes(format, size, image).save(buffer, format="JPEG")
+        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+        # A text part is added only when there is text, so the message never
+        # carries a null or repeated text entry.
+        content = []
+        if text:
+            content.append({"type": "text", "text": text})
+        content.append(
+            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
+        )
+        self.add_message({"role": "user", "content": content})
+
async def call_function(
self,
f: Callable[
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index 0089a152e..a329239dc 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -116,7 +116,7 @@ class LLMService(AIService):
tool_call_id: str,
function_name: str,
arguments: str,
- run_llm: bool,
+ run_llm: bool = True,
) -> None:
f = None
if function_name in self._callbacks.keys():
diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py
index 86e1e3726..639a922e6 100644
--- a/src/pipecat/services/anthropic.py
+++ b/src/pipecat/services/anthropic.py
@@ -55,6 +55,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
+# internal use only -- todo: refactor
@dataclass
class AnthropicImageMessageFrame(Frame):
user_image_raw_frame: UserImageRawFrame
@@ -359,7 +360,6 @@ class AnthropicLLMContext(OpenAILLMContext):
system: str | NotGiven = NOT_GIVEN,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
- self._user_image_request_context = {}
# For beta prompt caching. This is a counter that tracks the number of turns
# we've seen above the cache threshold. We reset this when we reset the
diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py
index 49fd04371..9a7cc9023 100644
--- a/src/pipecat/services/openai.py
+++ b/src/pipecat/services/openai.py
@@ -31,6 +31,8 @@ from pipecat.frames.frames import (
TTSStartedFrame,
TTSStoppedFrame,
URLImageRawFrame,
+ UserImageRawFrame,
+ UserImageRequestFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
@@ -181,7 +183,7 @@ class BaseOpenAILLMService(LLMService):
async def _stream_chat_completions(
self, context: OpenAILLMContext
) -> AsyncStream[ChatCompletionChunk]:
- logger.debug(f"Generating chat: {context.get_messages_json()}")
+ logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
messages: List[ChatCompletionMessageParam] = context.get_messages()
@@ -476,10 +478,49 @@ class OpenAITTSService(TTSService):
logger.exception(f"{self} error generating TTS: {e}")
+# internal use only -- todo: refactor
+@dataclass
+class OpenAIImageMessageFrame(Frame):
+    """Frame pairing a user's raw image frame with optional text for it.
+
+    Created by OpenAIUserContextAggregator and picked up by
+    OpenAIAssistantContextAggregator, which adds the image to the context.
+    """
+
+    user_image_raw_frame: UserImageRawFrame
+    text: Optional[str] = None
+
+
class OpenAIUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext):
super().__init__(context=context)
+    async def process_frame(self, frame, direction):
+        """Track image requests and wrap incoming user images for the assistant aggregator.
+
+        After the parent class has processed the frame:
+
+        - For a ``UserImageRequestFrame``, a non-empty string ``context`` is
+          cached per ``user_id``; otherwise any cached text for that user is dropped.
+        - For a ``UserImageRawFrame``, an ``OpenAIImageMessageFrame`` carrying the
+          image and the cached text (empty string when there is none) is pushed.
+
+        Errors raised while doing this are logged and not propagated.
+        """
+        await super().process_frame(frame, direction)
+        # Our parent method has already called push_frame(). So we can't interrupt the
+        # flow here and we don't need to call push_frame() ourselves.
+        try:
+            request_context = self._context._user_image_request_context
+            if isinstance(frame, UserImageRequestFrame):
+                # The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
+                # that frame so we can use it when we assemble the image message in the assistant
+                # context aggregator.
+                if frame.context and isinstance(frame.context, str):
+                    request_context[frame.user_id] = frame.context
+                else:
+                    if frame.context:
+                        logger.error(
+                            f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
+                        )
+                    # Drop any text left over from an earlier request. pop() with a default
+                    # does nothing when no text is cached for this user.
+                    request_context.pop(frame.user_id, None)
+            elif isinstance(frame, UserImageRawFrame):
+                # Push a new OpenAIImageMessageFrame with the text context we cached
+                # downstream to be handled by our assistant context aggregator. This is
+                # necessary so that we add the message to the context in the right order.
+                text = request_context.pop(frame.user_id, None) or ""
+                frame = OpenAIImageMessageFrame(user_image_raw_frame=frame, text=text)
+                await self.push_frame(frame)
+        except Exception as e:
+            logger.error(f"Error processing frame: {e}")
+
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
@@ -487,6 +528,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._user_context_aggregator = user_context_aggregator
self._function_calls_in_progress = {}
self._function_call_result = None
+ self._pending_image_frame_message = None
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -507,9 +549,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
)
self._function_call_result = None
+ elif isinstance(frame, OpenAIImageMessageFrame):
+ self._pending_image_frame_message = frame
+ await self._push_aggregation()
async def _push_aggregation(self):
- if not (self._aggregation or self._function_call_result):
+ if not (
+ self._aggregation or self._function_call_result or self._pending_image_frame_message
+ ):
return
run_llm = False
@@ -548,6 +595,17 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
else:
self._context.add_message({"role": "assistant", "content": aggregation})
+ if self._pending_image_frame_message:
+ frame = self._pending_image_frame_message
+ self._pending_image_frame_message = None
+ self._context.add_image_frame_message(
+ format=frame.user_image_raw_frame.format,
+ size=frame.user_image_raw_frame.size,
+ image=frame.user_image_raw_frame.image,
+ text=frame.text,
+ )
+ run_llm = True
+
if run_llm:
await self._user_context_aggregator.push_context_frame()
diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py
index 3f4d97964..5da470002 100644
--- a/src/pipecat/services/together.py
+++ b/src/pipecat/services/together.py
@@ -4,42 +4,21 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-import json
-import re
-import uuid
-from asyncio import CancelledError
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
-
+from typing import Any, Dict, Optional
+import httpx
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
- Frame,
- FunctionCallInProgressFrame,
- FunctionCallResultFrame,
- LLMFullResponseEndFrame,
- LLMFullResponseStartFrame,
- LLMMessagesFrame,
LLMUpdateSettingsFrame,
- StartInterruptionFrame,
- TextFrame,
- UserImageRequestFrame,
)
-from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.llm_response import (
- LLMAssistantContextAggregator,
- LLMUserContextAggregator,
-)
-from pipecat.processors.aggregators.openai_llm_context import (
- OpenAILLMContext,
- OpenAILLMContextFrame,
-)
-from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import LLMService
+from pipecat.services.openai import OpenAILLMService
+
try:
- from together import AsyncTogether
+ # Together.ai is recommending OpenAI-compatible function calling, so we've switched over
+ # to using the OpenAI client library here rather than the Together Python client library.
+ from openai import AsyncOpenAI, DefaultAsyncHttpxClient
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
@@ -48,19 +27,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
-@dataclass
-class TogetherContextAggregatorPair:
- _user: "TogetherUserContextAggregator"
- _assistant: "TogetherAssistantContextAggregator"
-
- def user(self) -> "TogetherUserContextAggregator":
- return self._user
-
- def assistant(self) -> "TogetherAssistantContextAggregator":
- return self._assistant
-
-
-class TogetherLLMService(LLMService):
+class TogetherLLMService(OpenAILLMService):
"""This class implements inference with Together's Llama 3.1 models"""
class InputParams(BaseModel):
@@ -68,20 +35,23 @@ class TogetherLLMService(LLMService):
max_tokens: Optional[int] = Field(default=4096, ge=1)
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+        # Note: top_k is currently not supported by the OpenAI client library,
+        # so top_k is ignored right now.
top_k: Optional[int] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
+ seed: Optional[int] = Field(default=None)
def __init__(
self,
*,
api_key: str,
+ base_url: str = "https://api.together.xyz/v1",
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
params: InputParams = InputParams(),
**kwargs,
):
- super().__init__(**kwargs)
- self._client = AsyncTogether(api_key=api_key)
+ super().__init__(api_key=api_key, base_url=base_url, model=model, params=params, **kwargs)
self.set_model_name(model)
self._max_tokens = params.max_tokens
self._frequency_penalty = params.frequency_penalty
@@ -94,15 +64,17 @@ class TogetherLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
- @staticmethod
- def create_context_aggregator(
- context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
- ) -> TogetherContextAggregatorPair:
- user = TogetherUserContextAggregator(context)
- assistant = TogetherAssistantContextAggregator(
- user, expect_stripped_words=assistant_expect_stripped_words
+ def create_client(self, api_key=None, base_url=None, **kwargs):
+ logger.debug(f"Creating Together.ai client with api {base_url}")
+ return AsyncOpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ http_client=DefaultAsyncHttpxClient(
+ limits=httpx.Limits(
+ max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
+ )
+ ),
)
- return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
async def set_frequency_penalty(self, frequency_penalty: float):
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
@@ -150,252 +122,3 @@ class TogetherLLMService(LLMService):
await self.set_top_p(frame.top_p)
if frame.extra:
await self.set_extra(frame.extra)
-
- async def _process_context(self, context: OpenAILLMContext):
- try:
- await self.push_frame(LLMFullResponseStartFrame())
- await self.start_processing_metrics()
-
- logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
-
- await self.start_ttfb_metrics()
-
- params = {
- "messages": context.messages,
- "model": self.model_name,
- "max_tokens": self._max_tokens,
- "stream": True,
- "frequency_penalty": self._frequency_penalty,
- "presence_penalty": self._presence_penalty,
- "temperature": self._temperature,
- "top_k": self._top_k,
- "top_p": self._top_p,
- }
-
- params.update(self._extra)
-
- stream = await self._client.chat.completions.create(**params)
-
- # Function calling
- got_first_chunk = False
- accumulating_function_call = False
- function_call_accumulator = ""
-
- async for chunk in stream:
- # logger.debug(f"Together LLM event: {chunk}")
- if chunk.usage:
- tokens = LLMTokenUsage(
- prompt_tokens=chunk.usage.prompt_tokens,
- completion_tokens=chunk.usage.completion_tokens,
- total_tokens=chunk.usage.total_tokens,
- )
- await self.start_llm_usage_metrics(tokens)
-
- if len(chunk.choices) == 0:
- continue
-
- if not got_first_chunk:
- await self.stop_ttfb_metrics()
- if chunk.choices[0].delta.content:
- got_first_chunk = True
- if chunk.choices[0].delta.content[0] == "<":
- accumulating_function_call = True
-
- if chunk.choices[0].delta.content:
- if accumulating_function_call:
- function_call_accumulator += chunk.choices[0].delta.content
- else:
- await self.push_frame(TextFrame(chunk.choices[0].delta.content))
-
- if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
- await self._extract_function_call(context, function_call_accumulator)
-
- except CancelledError:
- # todo: implement token counting estimates for use when the user interrupts a long generation
- # we do this in the anthropic.py service
- raise
- except Exception as e:
- logger.exception(f"{self} exception: {e}")
- finally:
- await self.stop_processing_metrics()
- await self.push_frame(LLMFullResponseEndFrame())
-
- async def process_frame(self, frame: Frame, direction: FrameDirection):
- await super().process_frame(frame, direction)
-
- context = None
- if isinstance(frame, OpenAILLMContextFrame):
- context = frame.context
- elif isinstance(frame, LLMMessagesFrame):
- context = TogetherLLMContext.from_messages(frame.messages)
- elif isinstance(frame, LLMUpdateSettingsFrame):
- await self._update_settings(frame)
- else:
- await self.push_frame(frame, direction)
-
- if context:
- await self._process_context(context)
-
- async def _extract_function_call(self, context, function_call_accumulator):
- context.add_message({"role": "assistant", "content": function_call_accumulator})
-
- function_regex = r"(.*?)"
- match = re.search(function_regex, function_call_accumulator)
- if match:
- function_name, args_string = match.groups()
- try:
- arguments = json.loads(args_string)
- await self.call_function(
- context=context,
- tool_call_id=str(uuid.uuid4()),
- function_name=function_name,
- arguments=arguments,
- )
- return
- except json.JSONDecodeError as error:
- # We get here if the LLM returns a function call with invalid JSON arguments. This could happen
- # because of LLM non-determinism, or maybe more often because of user error in the prompt.
- # Should we do anything more than log a warning?
- logger.debug(f"Error parsing function arguments: {error}")
-
-
-class TogetherLLMContext(OpenAILLMContext):
- def __init__(
- self,
- messages: list[dict] | None = None,
- ):
- super().__init__(messages=messages)
-
- @classmethod
- def from_openai_context(cls, openai_context: OpenAILLMContext):
- self = cls(
- messages=openai_context.messages,
- )
- return self
-
- @classmethod
- def from_messages(cls, messages: List[dict]) -> "TogetherLLMContext":
- return cls(messages=messages)
-
- def add_message(self, message):
- try:
- self.messages.append(message)
- except Exception as e:
- logger.error(f"Error adding message: {e}")
-
- def get_messages_for_logging(self) -> str:
- return json.dumps(self.messages)
-
-
-class TogetherUserContextAggregator(LLMUserContextAggregator):
- def __init__(self, context: OpenAILLMContext | TogetherLLMContext):
- super().__init__(context=context)
-
- if isinstance(context, OpenAILLMContext):
- self._context = TogetherLLMContext.from_openai_context(context)
-
- async def push_messages_frame(self):
- frame = OpenAILLMContextFrame(self._context)
- await self.push_frame(frame)
-
- async def process_frame(self, frame, direction):
- await super().process_frame(frame, direction)
- # Our parent method has already called push_frame(). So we can't interrupt the
- # flow here and we don't need to call push_frame() ourselves. Possibly something
- # to talk through (tagging @aleix). At some point we might need to refactor these
- # context aggregators.
- try:
- if isinstance(frame, UserImageRequestFrame):
- # The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
- # that frame so we can use it when we assemble the image message in the assistant
- # context aggregator.
- if frame.context:
- if isinstance(frame.context, str):
- self._context._user_image_request_context[frame.user_id] = frame.context
- else:
- logger.error(
- f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
- )
- del self._context._user_image_request_context[frame.user_id]
- else:
- if frame.user_id in self._context._user_image_request_context:
- del self._context._user_image_request_context[frame.user_id]
- except Exception as e:
- logger.error(f"Error processing frame: {e}")
-
-
-#
-# Claude returns a text content block along with a tool use content block. This works quite nicely
-# with streaming. We get the text first, so we can start streaming it right away. Then we get the
-# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
-#
-# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
-# chattiness about it's tool thinking.
-#
-
-
-class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
- def __init__(self, user_context_aggregator: TogetherUserContextAggregator, **kwargs):
- super().__init__(context=user_context_aggregator._context, **kwargs)
- self._user_context_aggregator = user_context_aggregator
- self._function_call_in_progress = None
- self._function_call_result = None
-
- async def process_frame(self, frame, direction):
- await super().process_frame(frame, direction)
- # See note above about not calling push_frame() here.
- if isinstance(frame, StartInterruptionFrame):
- self._function_call_in_progress = None
- self._function_call_finished = None
- elif isinstance(frame, FunctionCallInProgressFrame):
- self._function_call_in_progress = frame
- elif isinstance(frame, FunctionCallResultFrame):
- if (
- self._function_call_in_progress
- and self._function_call_in_progress.tool_call_id == frame.tool_call_id
- ):
- self._function_call_in_progress = None
- self._function_call_result = frame
- await self._push_aggregation()
- else:
- logger.warning(
- "FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
- )
- self._function_call_in_progress = None
- self._function_call_result = None
-
- def add_message(self, message):
- self._user_context_aggregator.add_message(message)
-
- async def _push_aggregation(self):
- if not (self._aggregation or self._function_call_result):
- return
-
- run_llm = False
-
- aggregation = self._aggregation
- self._reset()
-
- try:
- if self._function_call_result:
- frame = self._function_call_result
- self._function_call_result = None
- self._context.add_message(
- {
- "role": "tool",
- # Together expects the content here to be a string, so stringify it
- "content": str(frame.result),
- }
- )
- run_llm = True
- else:
- self._context.add_message({"role": "assistant", "content": aggregation})
-
- if run_llm:
- await self._user_context_aggregator.push_messages_frame()
-
- frame = OpenAILLMContextFrame(self._context)
- await self.push_frame(frame)
-
- except Exception as e:
- logger.error(f"Error processing frame: {e}")