diff --git a/CHANGELOG.md b/CHANGELOG.md index b712407cf..7d1cff375 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added foundational example `14u-function-calling-ollama.py` for Ollama + function calling. + - Added `LocalSmartTurnAnalyzerV2`, which supports local on-device inference with the new `smart-turn-v2` turn detection model. @@ -19,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an issue in `OLLamaLLMService` where kwargs were not passed correctly + to the parent class. + - Fixed an issue where, in some edge cases, the `EmulateUserStartedSpeakingFrame` could be created even if we didn't have a transcription. diff --git a/examples/foundational/14u-function-calling-ollama.py b/examples/foundational/14u-function-calling-ollama.py new file mode 100644 index 000000000..39220b791 --- /dev/null +++ b/examples/foundational/14u-function-calling-ollama.py @@ -0,0 +1,162 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import argparse +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import TTSSpeakFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.services.ollama.llm import OLLamaLLMService +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams +from pipecat.transports.services.daily import DailyParams + +load_dotenv(override=True) + + +async def fetch_weather_from_api(params: FunctionCallParams): + await params.result_callback({"conditions": "nice", "temperature": "75"}) + + +async def fetch_restaurant_recommendation(params: FunctionCallParams): + await params.result_callback({"name": "The Golden Dragon"}) + + +# We store functions so objects (e.g. SileroVADAnalyzer) don't get +# instantiated. The function will be called when the desired transport gets +# selected. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), +} + + +async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool): + logger.info(f"Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + llm = OLLamaLLMService(model="llama3.2") # Update to the model you're running locally + + # You can also register a function_name of None to get all functions + # sent to the same callback with an additional function_name parameter. + llm.register_function("get_current_weather", fetch_weather_from_api) + llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation) + + @llm.event_handler("on_function_calls_started") + async def on_function_calls_started(service, function_calls): + await tts.queue_frame(TTSSpeakFrame("Let me check on that.")) + + weather_function = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, + }, + required=["location", "format"], + ) + restaurant_function = FunctionSchema( + name="get_restaurant_recommendation", + description="Get a restaurant recommendation", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + required=["location"], + ) + tools = ToolsSchema(standard_tools=[weather_function, restaurant_function]) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + stt, + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=handle_sigint) + + await runner.run(task) + + +if __name__ == "__main__": + from pipecat.examples.run import main + + main(run_example, transport_params=transport_params) diff --git a/src/pipecat/services/ollama/llm.py b/src/pipecat/services/ollama/llm.py index de01dbc5b..2284a5070 100644 --- a/src/pipecat/services/ollama/llm.py +++ b/src/pipecat/services/ollama/llm.py @@ -42,4 +42,4 @@ class OLLamaLLMService(OpenAILLMService): An OpenAI-compatible client configured for Ollama. """ logger.debug(f"Creating Ollama client with api {base_url}") - return super().create_client(base_url, **kwargs) + return super().create_client(base_url=base_url, **kwargs)