Merge pull request #531 from pipecat-ai/khk/function-calling-improvements

This commit is contained in:
Kwindla Hultman Kramer
2024-10-01 07:23:38 -07:00
committed by GitHub
12 changed files with 426 additions and 450 deletions

View File

@@ -67,7 +67,7 @@ async def main():
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond in plain language. Respond to what the user said in a creative and helpful way.",
},
]
@@ -87,7 +87,12 @@ async def main():
]
)
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
task = PipelineTask(
pipeline,
PipelineParams(
allow_interruptions=True, enable_metrics=True, enable_usage_metrics=True
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):

View File

@@ -9,11 +9,9 @@ import aiohttp
import os
import sys
from pipecat.frames.frames import TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.logger import FrameLogger
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -72,9 +70,6 @@ async def main():
# sent to the same callback with an additional function_name parameter.
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
fl_in = FrameLogger("Inner")
fl_out = FrameLogger("Outer")
tools = [
ChatCompletionToolParam(
type="function",
@@ -111,11 +106,9 @@ async def main():
pipeline = Pipeline(
[
# fl_in,
transport.input(),
context_aggregator.user(),
llm,
# fl_out,
tts,
transport.output(),
context_aggregator.assistant(),

View File

@@ -0,0 +1,136 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMContext
from pipecat.services.together import TogetherLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
from openai.types.chat import ChatCompletionToolParam
from runner import configure
from loguru import logger
from dotenv import load_dotenv
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def start_fetch_weather(function_name, llm, context):
# note: we can't push a frame to the LLM here. the bot
# can interrupt itself and/or cause audio overlapping glitches.
# possible question for Aleix and Chad about what the right way
# to trigger speech is, now, with the new queues/async/sync refactors.
# await llm.push_frame(TextFrame("Let me check on that."))
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
await result_callback({"conditions": "nice", "temperature": "75"})
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = TogetherLLMService(
api_key=os.getenv("TOGETHER_API_KEY"),
model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
)
# Register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
tools = [
ChatCompletionToolParam(
type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location", "format"],
},
},
)
]
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(),
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(pipeline)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
# await tts.say("Hi! Ask me about the weather in San Francisco.")
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,167 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
from openai.types.chat import ChatCompletionToolParam
from runner import configure
from loguru import logger
from dotenv import load_dotenv
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
video_participant_id = None
async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
location = arguments["location"]
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}")
question = arguments["question"]
await llm.request_image_frame(user_id=video_participant_id, text_content=question)
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
llm.register_function("get_weather", get_weather)
llm.register_function("get_image", get_image)
tools = [
ChatCompletionToolParam(
type="function",
function={
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location", "format"],
},
},
),
ChatCompletionToolParam(
type="function",
function={
"name": "get_image",
"description": "Get an image from the video stream.",
"parameters": {
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The question to ask the AI to generate an image of",
},
},
"required": ["question"],
},
},
),
]
system_prompt = """\
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
Your response will be turned into speech so use only simple words and punctuation.
You have access to two tools: get_weather and get_image.
You can respond to questions about the weather using the get_weather tool.
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
indicate you should use the get_image tool are:
- What do you see?
- What's in the video?
- Can you describe the video?
- Tell me about what you see.
- Tell me something interesting about what you see.
- What's happening in the video?
"""
messages = [
{"role": "system", "content": system_prompt},
]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(),
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(pipeline)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
global video_participant_id
video_participant_id = participant["id"]
transport.capture_participant_transcription(participant["id"])
transport.capture_participant_video(video_participant_id, framerate=0)
# Kick off the conversation.
await tts.say("Hi! Ask me about the weather in San Francisco.")
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,137 +0,0 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
import json
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.together import TogetherLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
from runner import configure
from loguru import logger
from dotenv import load_dotenv
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def get_current_weather(
function_name, tool_call_id, arguments, llm, context, result_callback
):
logger.debug("IN get_current_weather")
location = arguments["location"]
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = TogetherLLMService(
api_key=os.getenv("TOGETHER_API_KEY"),
model=os.getenv("TOGETHER_MODEL"),
)
llm.register_function("get_current_weather", get_current_weather)
weatherTool = {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
"required": ["location"],
},
}
system_prompt = f"""\
You have access to the following functions:
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
{json.dumps(weatherTool)}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "Wait for the user to say something."},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(), # User speech to text
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses and tool context
]
)
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
await task.queue_frames([LLMMessagesFrame(messages)])
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import base64
import copy
import io
import json
@@ -60,6 +62,7 @@ class OpenAILLMContext:
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
self._user_image_request_context = {}
@staticmethod
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
@@ -114,6 +117,19 @@ class OpenAILLMContext:
def get_messages_json(self) -> str:
return json.dumps(self._messages, cls=CustomEncoder)
def get_messages_for_logging(self) -> str:
msgs = []
for message in self.messages:
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
for item in msg["content"]:
if item["type"] == "image_url":
if item["image_url"]["url"].startswith("data:image/"):
item["image_url"]["url"] = "data:image/..."
msgs.append(msg)
return json.dumps(msgs)
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
self._tool_choice = tool_choice
@@ -122,6 +138,21 @@ class OpenAILLMContext:
tools = NOT_GIVEN
self._tools = tools
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
content = [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
]
if text:
content.append({"type": "text", "text": text})
self.add_message({"role": "user", "content": content})
async def call_function(
self,
f: Callable[

View File

@@ -116,7 +116,7 @@ class LLMService(AIService):
tool_call_id: str,
function_name: str,
arguments: str,
run_llm: bool,
run_llm: bool = True,
) -> None:
f = None
if function_name in self._callbacks.keys():

View File

@@ -55,6 +55,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
# internal use only -- todo: refactor
@dataclass
class AnthropicImageMessageFrame(Frame):
user_image_raw_frame: UserImageRawFrame
@@ -359,7 +360,6 @@ class AnthropicLLMContext(OpenAILLMContext):
system: str | NotGiven = NOT_GIVEN,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self._user_image_request_context = {}
# For beta prompt caching. This is a counter that tracks the number of turns
# we've seen above the cache threshold. We reset this when we reset the

View File

@@ -31,6 +31,8 @@ from pipecat.frames.frames import (
TTSStartedFrame,
TTSStoppedFrame,
URLImageRawFrame,
UserImageRawFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
@@ -181,7 +183,7 @@ class BaseOpenAILLMService(LLMService):
async def _stream_chat_completions(
self, context: OpenAILLMContext
) -> AsyncStream[ChatCompletionChunk]:
logger.debug(f"Generating chat: {context.get_messages_json()}")
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
messages: List[ChatCompletionMessageParam] = context.get_messages()
@@ -476,10 +478,49 @@ class OpenAITTSService(TTSService):
logger.exception(f"{self} error generating TTS: {e}")
# internal use only -- todo: refactor
@dataclass
class OpenAIImageMessageFrame(Frame):
user_image_raw_frame: UserImageRawFrame
text: Optional[str] = None
class OpenAIUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext):
super().__init__(context=context)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# Our parent method has already called push_frame(). So we can't interrupt the
# flow here and we don't need to call push_frame() ourselves.
try:
if isinstance(frame, UserImageRequestFrame):
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
# that frame so we can use it when we assemble the image message in the assistant
# context aggregator.
if frame.context:
if isinstance(frame.context, str):
self._context._user_image_request_context[frame.user_id] = frame.context
else:
logger.error(
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
)
del self._context._user_image_request_context[frame.user_id]
else:
if frame.user_id in self._context._user_image_request_context:
del self._context._user_image_request_context[frame.user_id]
elif isinstance(frame, UserImageRawFrame):
# Push a new AnthropicImageMessageFrame with the text context we cached
# downstream to be handled by our assistant context aggregator. This is
# necessary so that we add the message to the context in the right order.
text = self._context._user_image_request_context.get(frame.user_id) or ""
if text:
del self._context._user_image_request_context[frame.user_id]
frame = OpenAIImageMessageFrame(user_image_raw_frame=frame, text=text)
await self.push_frame(frame)
except Exception as e:
logger.error(f"Error processing frame: {e}")
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
@@ -487,6 +528,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._user_context_aggregator = user_context_aggregator
self._function_calls_in_progress = {}
self._function_call_result = None
self._pending_image_frame_message = None
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -507,9 +549,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
)
self._function_call_result = None
elif isinstance(frame, OpenAIImageMessageFrame):
self._pending_image_frame_message = frame
await self._push_aggregation()
async def _push_aggregation(self):
if not (self._aggregation or self._function_call_result):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
return
run_llm = False
@@ -548,6 +595,17 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
else:
self._context.add_message({"role": "assistant", "content": aggregation})
if self._pending_image_frame_message:
frame = self._pending_image_frame_message
self._pending_image_frame_message = None
self._context.add_image_frame_message(
format=frame.user_image_raw_frame.format,
size=frame.user_image_raw_frame.size,
image=frame.user_image_raw_frame.image,
text=frame.text,
)
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()

View File

@@ -4,42 +4,21 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import json
import re
import uuid
from asyncio import CancelledError
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
StartInterruptionFrame,
TextFrame,
UserImageRequestFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.services.openai import OpenAILLMService
try:
from together import AsyncTogether
# Together.ai is recommending OpenAI-compatible function calling, so we've switched over
# to using the OpenAI client library here rather than the Together Python client library.
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
@@ -48,19 +27,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
@dataclass
class TogetherContextAggregatorPair:
_user: "TogetherUserContextAggregator"
_assistant: "TogetherAssistantContextAggregator"
def user(self) -> "TogetherUserContextAggregator":
return self._user
def assistant(self) -> "TogetherAssistantContextAggregator":
return self._assistant
class TogetherLLMService(LLMService):
class TogetherLLMService(OpenAILLMService):
"""This class implements inference with Together's Llama 3.1 models"""
class InputParams(BaseModel):
@@ -68,20 +35,23 @@ class TogetherLLMService(LLMService):
max_tokens: Optional[int] = Field(default=4096, ge=1)
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
# Note: top_k is currently not supported by the OpenAI client library,
# so top_k is ignore right now.
top_k: Optional[int] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
seed: Optional[int] = Field(default=None)
def __init__(
self,
*,
api_key: str,
base_url: str = "https://api.together.xyz/v1",
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(**kwargs)
self._client = AsyncTogether(api_key=api_key)
super().__init__(api_key=api_key, base_url=base_url, model=model, params=params, **kwargs)
self.set_model_name(model)
self._max_tokens = params.max_tokens
self._frequency_penalty = params.frequency_penalty
@@ -94,15 +64,17 @@ class TogetherLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
@staticmethod
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> TogetherContextAggregatorPair:
user = TogetherUserContextAggregator(context)
assistant = TogetherAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
def create_client(self, api_key=None, base_url=None, **kwargs):
logger.debug(f"Creating Together.ai client with api {base_url}")
return AsyncOpenAI(
api_key=api_key,
base_url=base_url,
http_client=DefaultAsyncHttpxClient(
limits=httpx.Limits(
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
)
),
)
return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
async def set_frequency_penalty(self, frequency_penalty: float):
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
@@ -150,252 +122,3 @@ class TogetherLLMService(LLMService):
await self.set_top_p(frame.top_p)
if frame.extra:
await self.set_extra(frame.extra)
async def _process_context(self, context: OpenAILLMContext):
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
await self.start_ttfb_metrics()
params = {
"messages": context.messages,
"model": self.model_name,
"max_tokens": self._max_tokens,
"stream": True,
"frequency_penalty": self._frequency_penalty,
"presence_penalty": self._presence_penalty,
"temperature": self._temperature,
"top_k": self._top_k,
"top_p": self._top_p,
}
params.update(self._extra)
stream = await self._client.chat.completions.create(**params)
# Function calling
got_first_chunk = False
accumulating_function_call = False
function_call_accumulator = ""
async for chunk in stream:
# logger.debug(f"Together LLM event: {chunk}")
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
)
await self.start_llm_usage_metrics(tokens)
if len(chunk.choices) == 0:
continue
if not got_first_chunk:
await self.stop_ttfb_metrics()
if chunk.choices[0].delta.content:
got_first_chunk = True
if chunk.choices[0].delta.content[0] == "<":
accumulating_function_call = True
if chunk.choices[0].delta.content:
if accumulating_function_call:
function_call_accumulator += chunk.choices[0].delta.content
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
except CancelledError:
# todo: implement token counting estimates for use when the user interrupts a long generation
# we do this in the anthropic.py service
raise
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context = frame.context
elif isinstance(frame, LLMMessagesFrame):
context = TogetherLLMContext.from_messages(frame.messages)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame)
else:
await self.push_frame(frame, direction)
if context:
await self._process_context(context)
async def _extract_function_call(self, context, function_call_accumulator):
context.add_message({"role": "assistant", "content": function_call_accumulator})
function_regex = r"<function=(\w+)>(.*?)</function>"
match = re.search(function_regex, function_call_accumulator)
if match:
function_name, args_string = match.groups()
try:
arguments = json.loads(args_string)
await self.call_function(
context=context,
tool_call_id=str(uuid.uuid4()),
function_name=function_name,
arguments=arguments,
)
return
except json.JSONDecodeError as error:
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
# Should we do anything more than log a warning?
logger.debug(f"Error parsing function arguments: {error}")
class TogetherLLMContext(OpenAILLMContext):
def __init__(
self,
messages: list[dict] | None = None,
):
super().__init__(messages=messages)
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
self = cls(
messages=openai_context.messages,
)
return self
@classmethod
def from_messages(cls, messages: List[dict]) -> "TogetherLLMContext":
return cls(messages=messages)
def add_message(self, message):
try:
self.messages.append(message)
except Exception as e:
logger.error(f"Error adding message: {e}")
def get_messages_for_logging(self) -> str:
return json.dumps(self.messages)
class TogetherUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext | TogetherLLMContext):
super().__init__(context=context)
if isinstance(context, OpenAILLMContext):
self._context = TogetherLLMContext.from_openai_context(context)
async def push_messages_frame(self):
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# Our parent method has already called push_frame(). So we can't interrupt the
# flow here and we don't need to call push_frame() ourselves. Possibly something
# to talk through (tagging @aleix). At some point we might need to refactor these
# context aggregators.
try:
if isinstance(frame, UserImageRequestFrame):
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
# that frame so we can use it when we assemble the image message in the assistant
# context aggregator.
if frame.context:
if isinstance(frame.context, str):
self._context._user_image_request_context[frame.user_id] = frame.context
else:
logger.error(
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
)
del self._context._user_image_request_context[frame.user_id]
else:
if frame.user_id in self._context._user_image_request_context:
del self._context._user_image_request_context[frame.user_id]
except Exception as e:
logger.error(f"Error processing frame: {e}")
#
# Claude returns a text content block along with a tool use content block. This works quite nicely
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
#
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
# chattiness about it's tool thinking.
#
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: TogetherUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
self._function_call_in_progress = None
self._function_call_result = None
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# See note above about not calling push_frame() here.
if isinstance(frame, StartInterruptionFrame):
self._function_call_in_progress = None
self._function_call_finished = None
elif isinstance(frame, FunctionCallInProgressFrame):
self._function_call_in_progress = frame
elif isinstance(frame, FunctionCallResultFrame):
if (
self._function_call_in_progress
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
):
self._function_call_in_progress = None
self._function_call_result = frame
await self._push_aggregation()
else:
logger.warning(
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
)
self._function_call_in_progress = None
self._function_call_result = None
def add_message(self, message):
self._user_context_aggregator.add_message(message)
async def _push_aggregation(self):
if not (self._aggregation or self._function_call_result):
return
run_llm = False
aggregation = self._aggregation
self._reset()
try:
if self._function_call_result:
frame = self._function_call_result
self._function_call_result = None
self._context.add_message(
{
"role": "tool",
# Together expects the content here to be a string, so stringify it
"content": str(frame.result),
}
)
run_llm = True
else:
self._context.add_message({"role": "assistant", "content": aggregation})
if run_llm:
await self._user_context_aggregator.push_messages_frame()
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
except Exception as e:
logger.error(f"Error processing frame: {e}")