Compare commits
37 Commits
v0.0.92
...
hush/realt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6101ff9661 | ||
|
|
c20aa78648 | ||
|
|
38f27ad991 | ||
|
|
0c38585034 | ||
|
|
8a09bbbf0e | ||
|
|
fb737ff671 | ||
|
|
b7a4d7371c | ||
|
|
ef88d6a2ea | ||
|
|
5c1bd8cda2 | ||
|
|
a82158045a | ||
|
|
b1533ddfc4 | ||
|
|
0abc699f24 | ||
|
|
09018071e8 | ||
|
|
1c53a5fd01 | ||
|
|
05d4753d3e | ||
|
|
87131850bc | ||
|
|
af83f45a49 | ||
|
|
62e45f466a | ||
|
|
e85e93b9b1 | ||
|
|
074d3ff162 | ||
|
|
d680ec2e69 | ||
|
|
d905b21f72 | ||
|
|
6c5d84ca4c | ||
|
|
57f6ae9e50 | ||
|
|
2d03e51109 | ||
|
|
09a7e08cbf | ||
|
|
6f172bba8f | ||
|
|
1433df4de2 | ||
|
|
8d0e7e5e16 | ||
|
|
e7b8da7a83 | ||
|
|
35c48a45cf | ||
|
|
14a365aa16 | ||
|
|
779fc0419d | ||
|
|
5052da8ce6 | ||
|
|
1ecf6e05fe | ||
|
|
5cc1d8a024 | ||
|
|
1e31fc7f9b |
36
CHANGELOG.md
36
CHANGELOG.md
@@ -5,10 +5,46 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for loading external observers. You can now register custom
|
||||
pipeline observers by setting the `PIPECAT_OBSERVER_FILES` environment
|
||||
variable. This variable should contain a colon-separated list of Python files
|
||||
(e.g. `export PIPECAT_OBSERVER_FILES="observer1.py:observer2.py:..."`). Each
|
||||
file must define a function with the following signature:
|
||||
|
||||
```python
|
||||
async def create_observers(task: PipelineTask) -> Iterable[BaseObserver]:
|
||||
...
|
||||
```
|
||||
|
||||
- Added support for new sonic-3 languages in `CartesiaTTSService` and
|
||||
`CartesiaHttpTTSService`.
|
||||
|
||||
- `EndFrame` and `EndTaskFrame` have an optional `reason` field to indicate why
|
||||
the pipeline is being ended.
|
||||
|
||||
- `CancelFrame` and `CancelTaskFrame` have an optional `reason` field to
|
||||
indicate why the pipeline is being canceled. This can be also specified when
|
||||
you cancel a task with `PipelineTask.cancel(reason="cancellation your
|
||||
reason")`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `GeminiLiveLLMService` session resumption after a connection timeout.
|
||||
|
||||
- `GeminiLiveLLMService` now properly supports context-provided system
|
||||
instruction and tools.
|
||||
|
||||
## [0.0.92] - 2025-10-31 🎃 "The Haunted Edition" 👻
|
||||
|
||||
### Added
|
||||
|
||||
- Added supprt for Sarvam Speech-to-Text service (`SarvamSTTService`) with streaming WebSocket
|
||||
support for `saarika` (STT) and `saaras` (STT-translate) models.
|
||||
|
||||
- Added a new `DeepgramHttpTTSService`, which delivers a meaningful reduction
|
||||
in latency when compared to the `DeepgramTTSService`.
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService
|
||||
from pipecat.services.sarvam.tts import SarvamHttpTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
@@ -63,7 +63,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Create an HTTP session
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = SarvamSTTService(
|
||||
api_key=os.getenv("SARVAM_API_KEY"),
|
||||
model="saarika:v2.5",
|
||||
)
|
||||
|
||||
tts = SarvamHttpTTSService(
|
||||
api_key=os.getenv("SARVAM_API_KEY"),
|
||||
|
||||
@@ -24,8 +24,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -62,7 +62,10 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = SarvamSTTService(
|
||||
api_key=os.getenv("SARVAM_API_KEY"),
|
||||
model="saarika:v2.5",
|
||||
)
|
||||
|
||||
tts = SarvamTTSService(
|
||||
api_key=os.getenv("SARVAM_API_KEY"),
|
||||
|
||||
@@ -75,7 +75,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct")
|
||||
llm = NimLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"),
|
||||
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Recommended when turning thinking off
|
||||
params=NimLLMService.InputParams(temperature=0.0),
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
@@ -102,6 +107,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
messages = [
|
||||
# Disable thinking by sending this message first
|
||||
# Check the model for the corresponding "no thinking" message
|
||||
{"role": "system", "content": "/no_think"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
|
||||
@@ -4,10 +4,39 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime API Example with Mem0 Memory Integration.
|
||||
|
||||
This example demonstrates how to use OpenAI's Realtime API with Pipecat
|
||||
for conversational AI with memory capabilities using Mem0.
|
||||
|
||||
The example:
|
||||
1. Sets up a real-time audio conversation using OpenAI's Realtime API
|
||||
2. Uses Mem0 to store and retrieve memories from conversations
|
||||
3. Creates personalized greetings based on previous interactions
|
||||
4. Demonstrates function calling capabilities
|
||||
5. Shows how to add tools dynamically at runtime
|
||||
|
||||
Example usage (run from pipecat root directory):
|
||||
$ pip install "pipecat-ai[daily,openai,mem0]"
|
||||
$ python examples/foundational/19-openai-realtime.py
|
||||
|
||||
Requirements:
|
||||
- OpenAI API key (for Realtime API)
|
||||
- Daily API key (for video/audio transport)
|
||||
- Mem0 API key (for cloud-based memory storage)
|
||||
- [Optional] Deepgram API key (for STT fallback)
|
||||
|
||||
Environment variables (set in .env or in your terminal using `export`):
|
||||
DAILY_SAMPLE_ROOM_URL=daily_sample_room_url
|
||||
DAILY_API_KEY=daily_api_key
|
||||
OPENAI_API_KEY=openai_api_key
|
||||
MEM0_API_KEY=mem0_api_key
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -27,6 +56,7 @@ from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.mem0.memory import Mem0MemoryService
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
@@ -42,6 +72,64 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
try:
|
||||
from mem0 import Memory, MemoryClient # noqa: F401
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Mem0, you need to `pip install mem0ai`. Also, set the environment variable MEM0_API_KEY."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
async def get_initial_greeting(
|
||||
memory_client: Union[MemoryClient, Memory], user_id: str, agent_id: str, run_id: str
|
||||
) -> str:
|
||||
"""Fetch all memories for the user and create a personalized greeting.
|
||||
|
||||
Returns:
|
||||
A personalized greeting based on user memories
|
||||
"""
|
||||
try:
|
||||
if isinstance(memory_client, Memory):
|
||||
filters = {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
|
||||
filters = {k: v for k, v in filters.items() if v is not None}
|
||||
memories = memory_client.get_all(**filters)
|
||||
else:
|
||||
# Create filters based on available IDs
|
||||
id_pairs = [("user_id", user_id), ("agent_id", agent_id), ("run_id", run_id)]
|
||||
clauses = [{name: value} for name, value in id_pairs if value is not None]
|
||||
filters = {"AND": clauses} if clauses else {}
|
||||
|
||||
# Get all memories for this user
|
||||
memories = memory_client.get_all(filters=filters, version="v2", output_format="v1.1")
|
||||
|
||||
if not memories or len(memories) == 0:
|
||||
logger.debug(f"!!! No memories found for this user. {memories}")
|
||||
return "Hello! It's nice to meet you. How can I help you today?"
|
||||
|
||||
# Create a personalized greeting based on memories
|
||||
greeting = "Hello! It's great to see you again. "
|
||||
|
||||
# Add some personalization based on memories (limit to 3 memories for brevity)
|
||||
if len(memories) > 0:
|
||||
greeting += "Based on our previous conversations, I remember: "
|
||||
for i, memory in enumerate(memories["results"][:3], 1):
|
||||
memory_content = memory.get("memory", "")
|
||||
# Keep memory references brief
|
||||
if len(memory_content) > 100:
|
||||
memory_content = memory_content[:97] + "..."
|
||||
greeting += f"{memory_content} "
|
||||
|
||||
greeting += "How can I help you today?"
|
||||
|
||||
logger.debug(f"Created personalized greeting from {len(memories)} memories")
|
||||
return greeting
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving initial memories from Mem0: {e}")
|
||||
return "Hello! How can I help you today?"
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
@@ -134,8 +222,62 @@ transport_params = {
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Note: You can pass the user_id as a parameter in API call
|
||||
USER_ID = "pipecat-realtime-user"
|
||||
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# =====================================================================
|
||||
# OPTION 1: Using Mem0 API (cloud-based approach)
|
||||
# This approach uses Mem0's cloud service for memory management
|
||||
# Requires: MEM0_API_KEY set in your environment
|
||||
# =====================================================================
|
||||
memory = Mem0MemoryService(
|
||||
api_key=os.getenv("MEM0_API_KEY"), # Your Mem0 API key
|
||||
user_id=USER_ID, # Unique identifier for the user
|
||||
agent_id="realtime-agent", # Optional identifier for the agent
|
||||
run_id="realtime-session", # Optional identifier for the run
|
||||
params=Mem0MemoryService.InputParams(
|
||||
search_limit=10,
|
||||
search_threshold=0.3,
|
||||
api_version="v2",
|
||||
system_prompt="Based on previous conversations, I recall: \n\n",
|
||||
add_as_system_message=True,
|
||||
position=1,
|
||||
),
|
||||
)
|
||||
|
||||
# =====================================================================
|
||||
# OPTION 2: Using Mem0 with local configuration (self-hosted approach)
|
||||
# This approach uses a local LLM configuration for memory management
|
||||
# Requires: Anthropic API key if using Claude model
|
||||
# =====================================================================
|
||||
# Uncomment the following code and comment out the previous memory initialization to use local config
|
||||
|
||||
# local_config = {
|
||||
# "llm": {
|
||||
# "provider": "anthropic",
|
||||
# "config": {
|
||||
# "model": "claude-3-5-sonnet-20240620",
|
||||
# "api_key": os.getenv("ANTHROPIC_API_KEY"), # Make sure to set this in your .env
|
||||
# }
|
||||
# },
|
||||
# "embedder": {
|
||||
# "provider": "openai",
|
||||
# "config": {
|
||||
# "model": "text-embedding-3-large"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
|
||||
# # Initialize Mem0 memory service with local configuration
|
||||
# memory = Mem0MemoryService(
|
||||
# local_config=local_config, # Use local LLM for memory processing
|
||||
# user_id=USER_ID, # Unique identifier for the user
|
||||
# # agent_id="realtime-agent", # Optional identifier for the agent
|
||||
# # run_id="realtime-session", # Optional identifier for the run
|
||||
# )
|
||||
|
||||
session_properties = SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
@@ -149,7 +291,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
)
|
||||
),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
instructions="""You are a helpful and friendly AI with memory capabilities.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
@@ -162,6 +304,9 @@ even if you're asked about them.
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You can remember things about the person you are talking to. If the user asks you to remember
|
||||
something, make sure to remember it. Greet the user by their name if you know about it.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
@@ -182,8 +327,9 @@ Remember, your responses should be short. Just one or two sentences, usually. Re
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
# We'll add the initial greeting message after getting memories
|
||||
context = LLMContext(
|
||||
[{"role": "user", "content": "Say hello!"}],
|
||||
[],
|
||||
tools,
|
||||
)
|
||||
|
||||
@@ -194,6 +340,7 @@ Remember, your responses should be short. Just one or two sentences, usually. Re
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
transcript.user(), # LLM pushes TranscriptionFrames upstream
|
||||
memory, # Mem0 memory service
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # After the transcript output, to time with the audio output
|
||||
@@ -214,6 +361,18 @@ Remember, your responses should be short. Just one or two sentences, usually. Re
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
|
||||
# Get personalized greeting based on user memories
|
||||
greeting = await get_initial_greeting(
|
||||
memory_client=memory.memory_client,
|
||||
user_id=USER_ID,
|
||||
agent_id="realtime-agent",
|
||||
run_id="realtime-session",
|
||||
)
|
||||
|
||||
# Add the greeting as a user message to start the conversation
|
||||
context.add_message({"role": "user", "content": greeting})
|
||||
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
|
||||
@@ -141,8 +141,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
# You can provide the system instructions and tools in the context rather
|
||||
# than as arguments to GeminiLiveLLMService, but note that doing so will
|
||||
# trigger a (fast) reconnection when the GeminiLiveLLMService first
|
||||
# receives the context (i.e. when we send the LLMRunFrame below).
|
||||
context = LLMContext(
|
||||
[{"role": "user", "content": "Say hello."}],
|
||||
[
|
||||
# {"role": "system", "content": system_instruction},
|
||||
{"role": "user", "content": "Say hello."},
|
||||
],
|
||||
# tools,
|
||||
)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -63,10 +64,12 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
pattern = r"!\[[^\]]*\]\((https?://[^)]+\.(png|jpg|jpeg|PNG|JPG|JPEG))\)"
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
return None
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
@@ -130,9 +133,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
args=["-y", "@programcomputer/nasa-mcp-server@latest"],
|
||||
# https://api.nasa.gov
|
||||
env={"NASA_API_KEY": os.getenv("NASA_API_KEY")},
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -141,15 +144,20 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
mcp_image = UrlToImageProcessor(aiohttp_session=session)
|
||||
|
||||
tools = await mcp.register_tools(llm)
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to a number of tools provided by NASA MCP. Use any and all tools to help users.
|
||||
When asked for the astronomy picture of the day, PASS in NO date to the API.
|
||||
This ensures we get the latest picture available. If as specific date is asked for, you
|
||||
can pass in that date to the API.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
Your output will be converted to audio so don't include special characters in your answers.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
@@ -206,14 +214,13 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("NASA_API_KEY"):
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY"):
|
||||
logger.error(
|
||||
f"Please set NASA_API_KEY environment variable for this example. See https://api.nasa.gov"
|
||||
f"Please set RIJKSMUSEUM_API_KEY environment variable for this example. See https://github.com/r-huijts/rijksmuseum-mcp and https://www.rijksmuseum.nl/en/register?redirectUrl=https://www.https://www.rijksmuseum.nl/en/rijksstudio/my/profile"
|
||||
)
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
|
||||
@@ -79,7 +79,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = await mcp.register_tools(llm)
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
|
||||
@@ -132,9 +132,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to a number of tools provided by NASA MCP. Use any and all tools to help users.
|
||||
When asked for today's date, use 'https://www.datetoday.net/'.
|
||||
When asked for the astronomy picture of the day, use 'https://www.datetoday.net/', to get today's date.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
Your output will be converted to audio so don't include special characters in your answers.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
@@ -147,13 +148,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
args=["-y", "@programcomputer/nasa-mcp-server@latest"],
|
||||
# https://api.nasa.gov
|
||||
env={"NASA_API_KEY": os.getenv("NASA_API_KEY")},
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-error setting up mcp"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up nasa mcp")
|
||||
logger.error(f"error setting up rijksmuseum mcp")
|
||||
logger.exception("error trace:")
|
||||
try:
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
@@ -164,8 +165,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.error(f"error setting up mcp.run")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = await mcp.register_tools(llm)
|
||||
run_tools = await mcp_run.register_tools(llm)
|
||||
tools = {}
|
||||
run_tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
run_tools = await mcp_run.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
all_standard_tools = run_tools.standard_tools + tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
@@ -219,9 +226,9 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("NASA_API_KEY") or not os.getenv("MCP_RUN_SSE_URL"):
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("MCP_RUN_SSE_URL"):
|
||||
logger.error(
|
||||
f"Please set NASA_API_KEY and MCP_RUN_SSE_URL environment variables. See https://api.nasa.gov and https://mcp.run"
|
||||
f"Please set RIJKSMUSEUM_API_KEY and MCP_RUN_SSE_URL environment variables. See https://github.com/r-huijts/rijksmuseum-mcp and https://mcp.run"
|
||||
)
|
||||
import sys
|
||||
|
||||
|
||||
@@ -85,7 +85,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = await mcp.register_tools(llm)
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
|
||||
@@ -93,7 +93,7 @@ rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sambanova = []
|
||||
sarvam = [ "pipecat-ai[websockets-base]" ]
|
||||
sarvam = [ "sarvamai==0.1.21", "pipecat-ai[websockets-base]" ]
|
||||
sentry = [ "sentry-sdk>=2.28.0,<3" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
|
||||
|
||||
@@ -773,9 +773,15 @@ class CancelFrame(SystemFrame):
|
||||
|
||||
Indicates that a pipeline needs to stop right away without
|
||||
processing remaining queued frames.
|
||||
|
||||
Parameters:
|
||||
reason: Optional reason for pushing a cancel frame.
|
||||
"""
|
||||
|
||||
pass
|
||||
reason: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1366,9 +1372,15 @@ class EndTaskFrame(TaskFrame):
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
closed nicely (flushing all the queued frames) by pushing an EndFrame
|
||||
downstream. This frame should be pushed upstream.
|
||||
|
||||
Parameters:
|
||||
reason: Optional reason for pushing an end frame.
|
||||
"""
|
||||
|
||||
pass
|
||||
reason: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1378,9 +1390,15 @@ class CancelTaskFrame(TaskFrame):
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
stopped immediately by pushing a CancelFrame downstream. This frame
|
||||
should be pushed upstream.
|
||||
|
||||
Parameters:
|
||||
reason: Optional reason for pushing a cancel frame.
|
||||
"""
|
||||
|
||||
pass
|
||||
reason: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1451,9 +1469,15 @@ class EndFrame(ControlFrame):
|
||||
sending frames to its output channel(s) and close all its threads. Note,
|
||||
that this is a control frame, which means it will be received in the order it
|
||||
was sent.
|
||||
|
||||
Parameters:
|
||||
reason: Optional reason for pushing an end frame.
|
||||
"""
|
||||
|
||||
pass
|
||||
reason: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,6 +12,9 @@ including heartbeats, idle detection, and observer integration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
@@ -446,10 +449,14 @@ class PipelineTask(BasePipelineTask):
|
||||
logger.debug(f"Task {self} scheduled to stop when done")
|
||||
await self.queue_frame(EndFrame())
|
||||
|
||||
async def cancel(self):
|
||||
"""Request the running pipeline to cancel."""
|
||||
async def cancel(self, *, reason: Optional[str] = None):
|
||||
"""Request the running pipeline to cancel.
|
||||
|
||||
Args:
|
||||
reason: Optional reason to indicate why the pipeline is being cancelled.
|
||||
"""
|
||||
if not self._finished:
|
||||
await self._cancel()
|
||||
await self._cancel(reason=reason)
|
||||
|
||||
async def run(self, params: PipelineTaskParams):
|
||||
"""Start and manage the pipeline execution until completion or cancellation.
|
||||
@@ -513,12 +520,16 @@ class PipelineTask(BasePipelineTask):
|
||||
for frame in frames:
|
||||
await self.queue_frame(frame)
|
||||
|
||||
async def _cancel(self):
|
||||
"""Internal cancellation logic for the pipeline task."""
|
||||
async def _cancel(self, *, reason: Optional[str] = None):
|
||||
"""Internal cancellation logic for the pipeline task.
|
||||
|
||||
Args:
|
||||
reason: Optional reason to indicate why the pipeline is being cancelled.
|
||||
"""
|
||||
if not self._cancelled:
|
||||
logger.debug(f"Cancelling pipeline task {self}")
|
||||
self._cancelled = True
|
||||
await self.queue_frame(CancelFrame())
|
||||
await self.queue_frame(CancelFrame(reason=reason))
|
||||
|
||||
async def _create_tasks(self):
|
||||
"""Create and start all pipeline processing tasks."""
|
||||
@@ -633,6 +644,9 @@ class PipelineTask(BasePipelineTask):
|
||||
|
||||
async def _setup(self, params: PipelineTaskParams):
|
||||
"""Set up the pipeline task and all processors."""
|
||||
# Load additional observers.
|
||||
await self._load_observer_files()
|
||||
|
||||
mgr_params = TaskManagerParams(loop=params.loop)
|
||||
self._task_manager.setup(mgr_params)
|
||||
|
||||
@@ -716,11 +730,11 @@ class PipelineTask(BasePipelineTask):
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
# Tell the task we should end nicely.
|
||||
logger.debug(f"{self}: received end task frame {frame}")
|
||||
await self.queue_frame(EndFrame())
|
||||
await self.queue_frame(EndFrame(reason=frame.reason))
|
||||
elif isinstance(frame, CancelTaskFrame):
|
||||
# Tell the task we should end right away.
|
||||
logger.debug(f"{self}: received cancel task frame {frame}")
|
||||
await self.queue_frame(CancelFrame())
|
||||
await self.queue_frame(CancelFrame(reason=frame.reason))
|
||||
elif isinstance(frame, StopTaskFrame):
|
||||
# Tell the task we should stop nicely.
|
||||
logger.debug(f"{self}: received stop task frame {frame}")
|
||||
@@ -836,6 +850,27 @@ class PipelineTask(BasePipelineTask):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _load_observer_files(self):
|
||||
observer_files = os.environ.get("PIPECAT_OBSERVER_FILES", "").split(":")
|
||||
for f in observer_files:
|
||||
try:
|
||||
path = Path(f).resolve()
|
||||
module_name = path.stem
|
||||
spec = importlib.util.spec_from_file_location(module_name, str(path))
|
||||
if spec:
|
||||
logger.debug(f"{self} loading observers from {path}")
|
||||
|
||||
# Load module.
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Create observers.
|
||||
observers = await module.create_observers(self)
|
||||
for observer in observers:
|
||||
self.add_observer(observer)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error loading external observers from {f}: {e}")
|
||||
|
||||
def _print_dangling_tasks(self):
|
||||
"""Log any dangling tasks that haven't been properly cleaned up."""
|
||||
tasks = [t.get_name() for t in self._task_manager.current_tasks()]
|
||||
|
||||
@@ -216,6 +216,7 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
"account_sid": start_data.get("account_sid"),
|
||||
"from": start_data.get("from", ""),
|
||||
"to": start_data.get("to", ""),
|
||||
"custom_parameters": start_data.get("custom_parameters", ""),
|
||||
}
|
||||
|
||||
else:
|
||||
|
||||
@@ -78,20 +78,47 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
|
||||
The corresponding Cartesia language code, or None if not supported.
|
||||
"""
|
||||
BASE_LANGUAGES = {
|
||||
Language.AR: "ar",
|
||||
Language.BG: "bg",
|
||||
Language.BN: "bn",
|
||||
Language.CS: "cs",
|
||||
Language.DA: "da",
|
||||
Language.DE: "de",
|
||||
Language.EN: "en",
|
||||
Language.EL: "el",
|
||||
Language.ES: "es",
|
||||
Language.FI: "fi",
|
||||
Language.FR: "fr",
|
||||
Language.GU: "gu",
|
||||
Language.HE: "he",
|
||||
Language.HI: "hi",
|
||||
Language.HR: "hr",
|
||||
Language.HU: "hu",
|
||||
Language.ID: "id",
|
||||
Language.IT: "it",
|
||||
Language.JA: "ja",
|
||||
Language.KA: "ka",
|
||||
Language.KN: "kn",
|
||||
Language.KO: "ko",
|
||||
Language.ML: "ml",
|
||||
Language.MR: "mr",
|
||||
Language.MS: "ms",
|
||||
Language.NL: "nl",
|
||||
Language.NO: "no",
|
||||
Language.PA: "pa",
|
||||
Language.PL: "pl",
|
||||
Language.PT: "pt",
|
||||
Language.RO: "ro",
|
||||
Language.RU: "ru",
|
||||
Language.SK: "sk",
|
||||
Language.SV: "sv",
|
||||
Language.TA: "ta",
|
||||
Language.TE: "te",
|
||||
Language.TH: "th",
|
||||
Language.TL: "tl",
|
||||
Language.TR: "tr",
|
||||
Language.UK: "uk",
|
||||
Language.VI: "vi",
|
||||
Language.ZH: "zh",
|
||||
}
|
||||
|
||||
|
||||
@@ -672,8 +672,8 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
|
||||
self._system_instruction = system_instruction
|
||||
self._tools = tools
|
||||
self._system_instruction_from_init = system_instruction
|
||||
self._tools_from_init = tools
|
||||
self._inference_on_context_initialization = inference_on_context_initialization
|
||||
self._needs_turn_complete_message = False
|
||||
|
||||
@@ -964,16 +964,51 @@ class GeminiLiveLLMService(LLMService):
|
||||
if not self._context:
|
||||
# We got our initial context
|
||||
self._context = context
|
||||
if context.tools:
|
||||
self._tools = context.tools
|
||||
|
||||
# If context contains system instruction or tools, reconnect in
|
||||
# order to apply them.
|
||||
# (Context-provided system instruction and tools take precedence
|
||||
# over the ones provided at initialization time. Note that we could
|
||||
# do more sophisticated comparisons here, but for now this is
|
||||
# sufficient: we'll assume folks won't mean to provide these
|
||||
# settings both in the context and at initialization time. In a
|
||||
# future change, we could/should implement the ability to swap
|
||||
# these settings at any point).
|
||||
adapter: GeminiLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(self._context)
|
||||
system_instruction = params["system_instruction"]
|
||||
tools = params["tools"]
|
||||
if system_instruction and self._system_instruction_from_init:
|
||||
logger.warning(
|
||||
"System instruction provided both at init time and in context; using context-provided value."
|
||||
)
|
||||
if tools and self._tools_from_init:
|
||||
logger.warning(
|
||||
"Tools provided both at init time and in context; using context-provided value."
|
||||
)
|
||||
if system_instruction or tools:
|
||||
await self._reconnect()
|
||||
|
||||
# Initialize our bookkeeping of already-completed tool calls in
|
||||
# the context
|
||||
await self._process_completed_function_calls(send_new_results=False)
|
||||
|
||||
# Create initial response if needed, based on conversation history
|
||||
# in context
|
||||
await self._create_initial_response()
|
||||
else:
|
||||
# We got an updated context.
|
||||
# This may contain a new user message or tool call result.
|
||||
self._context = context
|
||||
|
||||
# Here we assume that the updated context will contain either:
|
||||
# - new messages (that the Gemini Live service, with its own
|
||||
# context management, is already aware of), or
|
||||
# - tool call results (that we need to tell the remote service
|
||||
# about).
|
||||
# (In the future, we could do more sophisticated diffing here,
|
||||
# which would enable the user to programmatically manipulate the
|
||||
# context).
|
||||
|
||||
# Send results for newly-completed function calls, if any.
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
@@ -1103,18 +1138,25 @@ class GeminiLiveLLMService(LLMService):
|
||||
automatic_activity_detection=vad_config
|
||||
)
|
||||
|
||||
# Add system instruction to configuration, if provided
|
||||
system_instruction = self._system_instruction or ""
|
||||
if self._context and hasattr(self._context, "extract_system_instructions"):
|
||||
system_instruction += "\n" + self._context.extract_system_instructions()
|
||||
# Add system instruction and tools to configuration, if provided.
|
||||
# These settings from the context take precedence over the ones
|
||||
# provided at initialization time.
|
||||
adapter: GeminiLLMAdapter = self.get_llm_adapter()
|
||||
system_instruction = None
|
||||
tools = None
|
||||
if self._context:
|
||||
params = adapter.get_llm_invocation_params(self._context)
|
||||
system_instruction = params["system_instruction"]
|
||||
tools = params["tools"]
|
||||
else:
|
||||
system_instruction = self._system_instruction_from_init
|
||||
tools = adapter.from_standard_tools(self._tools_from_init)
|
||||
if system_instruction:
|
||||
logger.debug(f"Setting system instruction: {system_instruction}")
|
||||
config.system_instruction = system_instruction
|
||||
|
||||
# Add tools to configuration, if provided
|
||||
if self._tools:
|
||||
logger.debug(f"Setting tools: {self._tools}")
|
||||
config.tools = self.get_llm_adapter().from_standard_tools(self._tools)
|
||||
if tools:
|
||||
logger.debug(f"Setting tools: {tools}")
|
||||
config.tools = tools
|
||||
|
||||
# Start the connection
|
||||
self._connection_task = self.create_task(self._connection_task_handler(config=config))
|
||||
@@ -1675,13 +1717,17 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._session_resumption_handle = update.new_handle
|
||||
|
||||
async def _handle_send_error(self, error: Exception):
|
||||
# Ignore "expected" errors that may have occurred for messages that
|
||||
# were in-flight when a disconnection occurred.
|
||||
if self._disconnecting or not self._session:
|
||||
return
|
||||
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare.
|
||||
# Given how hard it is to recover from a send-side error with proper
|
||||
# state management, and that exponential backoff for retries can have
|
||||
# cost/stability implications for a service cluster, let's just treat a
|
||||
# send-side error as fatal.
|
||||
if not self._disconnecting:
|
||||
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}", fatal=True))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
|
||||
468
src/pipecat/services/sarvam/stt.py
Normal file
468
src/pipecat/services/sarvam/stt.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""Sarvam AI Speech-to-Text service implementation.
|
||||
|
||||
This module provides a streaming Speech-to-Text service using Sarvam AI's WebSocket-based
|
||||
API. It supports real-time transcription with Voice Activity Detection (VAD) and
|
||||
can handle multiple audio formats for Indian language speech recognition.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from sarvamai import AsyncSarvamAI
|
||||
from sarvamai.core.api_error import ApiError
|
||||
from sarvamai.core.events import EventType
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Sarvam, you need to `pip install pipecat-ai[sarvam]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_sarvam_language(language: Language) -> str:
|
||||
"""Convert a Language enum to Sarvam's language code format.
|
||||
|
||||
Args:
|
||||
language: The Language enum value to convert.
|
||||
|
||||
Returns:
|
||||
The Sarvam language code string.
|
||||
"""
|
||||
# Mapping of pipecat Language enum to Sarvam language codes
|
||||
SARVAM_LANGUAGES = {
|
||||
Language.BN_IN: "bn-IN",
|
||||
Language.GU_IN: "gu-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
Language.KN_IN: "kn-IN",
|
||||
Language.ML_IN: "ml-IN",
|
||||
Language.MR_IN: "mr-IN",
|
||||
Language.TA_IN: "ta-IN",
|
||||
Language.TE_IN: "te-IN",
|
||||
Language.PA_IN: "pa-IN",
|
||||
Language.OR_IN: "od-IN",
|
||||
Language.EN_IN: "en-IN",
|
||||
Language.AS_IN: "as-IN",
|
||||
}
|
||||
|
||||
return SARVAM_LANGUAGES.get(
|
||||
language, "unknown"
|
||||
) # Default to unknown (Sarvam models auto-detect the language)
|
||||
|
||||
|
||||
class SarvamSTTService(STTService):
|
||||
"""Sarvam speech-to-text service.
|
||||
|
||||
Provides real-time speech recognition using Sarvam's WebSocket API.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Sarvam STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to None (required for saarika models).
|
||||
prompt: Optional prompt to guide translation style/context for STT-Translate models.
|
||||
Only applicable to saaras (STT-Translate) models. Defaults to None.
|
||||
vad_signals: Enable VAD signals in response. Defaults to True.
|
||||
high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to False.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = None
|
||||
prompt: Optional[str] = None
|
||||
vad_signals: bool = True
|
||||
high_vad_sensitivity: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "saarika:v2.5",
|
||||
sample_rate: Optional[int] = None,
|
||||
input_audio_codec: str = "wav",
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Sarvam STT service.
|
||||
|
||||
Args:
|
||||
api_key: Sarvam API key for authentication.
|
||||
model: Sarvam model to use for transcription.
|
||||
sample_rate: Audio sample rate. Defaults to 16000 if not specified.
|
||||
input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
|
||||
params: Configuration parameters for Sarvam STT service.
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
params = params or SarvamSTTService.InputParams()
|
||||
|
||||
# Validate that saaras models don't accept language parameter
|
||||
if "saaras" in model.lower():
|
||||
if params.language is not None:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not accept language parameter. "
|
||||
"STT-Translate models auto-detect language."
|
||||
)
|
||||
|
||||
# Validate that saarika models don't accept prompt parameter
|
||||
if "saarika" in model.lower():
|
||||
if params.prompt is not None:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not accept prompt parameter. "
|
||||
"Prompts are only supported for STT-Translate models"
|
||||
)
|
||||
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self.set_model_name(model)
|
||||
self._api_key = api_key
|
||||
self._language_code = params.language
|
||||
# For saarika models, default to "unknown" if language is not provided
|
||||
if params.language:
|
||||
self._language_string = language_to_sarvam_language(params.language)
|
||||
elif "saarika" in model.lower():
|
||||
self._language_string = "unknown"
|
||||
else:
|
||||
self._language_string = None
|
||||
self._prompt = params.prompt
|
||||
|
||||
# Store connection parameters
|
||||
self._vad_signals = params.vad_signals
|
||||
self._high_vad_sensitivity = params.high_vad_sensitivity
|
||||
self._input_audio_codec = input_audio_codec
|
||||
|
||||
# Initialize Sarvam SDK client
|
||||
self._sarvam_client = AsyncSarvamAI(api_subscription_key=api_key)
|
||||
self._websocket_context = None
|
||||
self._socket_client = None
|
||||
self._receive_task = None
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str:
|
||||
"""Convert pipecat Language enum to Sarvam's language code.
|
||||
|
||||
Args:
|
||||
language: The Language enum value to convert.
|
||||
|
||||
Returns:
|
||||
The Sarvam language code string.
|
||||
"""
|
||||
return language_to_sarvam_language(language)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Sarvam service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the recognition language and reconnect.
|
||||
|
||||
Args:
|
||||
language: The language to use for speech recognition.
|
||||
"""
|
||||
# saaras models do not accept a language parameter
|
||||
if "saaras" in self.model_name.lower():
|
||||
raise ValueError(
|
||||
f"Model '{self.model_name}' (saaras) does not accept language parameter. "
|
||||
"saaras models auto-detect language."
|
||||
)
|
||||
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_code = language
|
||||
self._language_string = language_to_sarvam_language(language)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_prompt(self, prompt: Optional[str]):
|
||||
"""Set the translation prompt and reconnect.
|
||||
|
||||
Args:
|
||||
prompt: Prompt text to guide translation style/context.
|
||||
Pass None to clear/disable prompt.
|
||||
Only applicable to STT-Translate models, not STT models.
|
||||
"""
|
||||
# saarika models do not accept prompt parameter
|
||||
if "saarika" in self.model_name.lower():
|
||||
if prompt is not None:
|
||||
raise ValueError(
|
||||
f"Model '{self.model_name}' does not accept prompt parameter. "
|
||||
"Prompts are only supported for STT-Translate models."
|
||||
)
|
||||
# If prompt is None and it's saarika, just silently return (no-op)
|
||||
return
|
||||
|
||||
logger.info("Updating STT-Translate prompt.")
|
||||
self._prompt = prompt
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Sarvam STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Sarvam STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Sarvam STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes):
|
||||
"""Send audio data to Sarvam for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via WebSocket callbacks).
|
||||
"""
|
||||
if not self._socket_client:
|
||||
logger.warning("WebSocket not connected, cannot process audio")
|
||||
yield None
|
||||
return
|
||||
|
||||
try:
|
||||
# Convert audio bytes to base64 for Sarvam API
|
||||
audio_base64 = base64.b64encode(audio).decode("utf-8")
|
||||
|
||||
# Convert input_audio_codec to encoding format (prepend "audio/" if needed)
|
||||
encoding = (
|
||||
self._input_audio_codec
|
||||
if self._input_audio_codec.startswith("audio/")
|
||||
else f"audio/{self._input_audio_codec}"
|
||||
)
|
||||
|
||||
# Build method arguments
|
||||
method_kwargs = {
|
||||
"audio": audio_base64,
|
||||
"encoding": encoding,
|
||||
"sample_rate": self.sample_rate,
|
||||
}
|
||||
|
||||
# Use appropriate method based on service type
|
||||
if "saarika" in self.model_name.lower():
|
||||
# STT service
|
||||
await self._socket_client.transcribe(**method_kwargs)
|
||||
else:
|
||||
# STT-Translate service - auto-detects input language and returns translated text
|
||||
await self._socket_client.translate(**method_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio to Sarvam: {e}")
|
||||
await self.push_error(ErrorFrame(f"Failed to send audio: {e}"))
|
||||
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Sarvam WebSocket API using the SDK."""
|
||||
logger.debug("Connecting to Sarvam")
|
||||
|
||||
try:
|
||||
# Convert boolean parameters to string for SDK
|
||||
vad_signals_str = "true" if self._vad_signals else "false"
|
||||
high_vad_sensitivity_str = "true" if self._high_vad_sensitivity else "false"
|
||||
|
||||
# Build common connection parameters
|
||||
connect_kwargs = {
|
||||
"model": self.model_name,
|
||||
"vad_signals": vad_signals_str,
|
||||
"high_vad_sensitivity": high_vad_sensitivity_str,
|
||||
"input_audio_codec": self._input_audio_codec,
|
||||
"sample_rate": str(self.sample_rate),
|
||||
}
|
||||
|
||||
# Choose the appropriate service based on model
|
||||
if "saarika" in self.model_name.lower():
|
||||
# STT service - requires language_code
|
||||
connect_kwargs["language_code"] = self._language_string
|
||||
self._websocket_context = self._sarvam_client.speech_to_text_streaming.connect(
|
||||
**connect_kwargs
|
||||
)
|
||||
else:
|
||||
# STT-Translate service - auto-detects input language and returns translated text
|
||||
self._websocket_context = (
|
||||
self._sarvam_client.speech_to_text_translate_streaming.connect(**connect_kwargs)
|
||||
)
|
||||
|
||||
# Enter the async context manager
|
||||
self._socket_client = await self._websocket_context.__aenter__()
|
||||
|
||||
# Set prompt if provided (only for STT-Translate models, after connection)
|
||||
if self._prompt is not None and "saaras" in self.model_name.lower():
|
||||
await self._socket_client.set_prompt(self._prompt)
|
||||
|
||||
# Register event handler for incoming messages
|
||||
def _message_handler(message):
|
||||
"""Wrapper to handle async response handler."""
|
||||
# Use Pipecat's built-in task management
|
||||
self.create_task(self._handle_message(message))
|
||||
|
||||
self._socket_client.on(EventType.MESSAGE, _message_handler)
|
||||
|
||||
# Start receive task using Pipecat's task management
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
|
||||
logger.info("Connected to Sarvam successfully")
|
||||
|
||||
except ApiError as e:
|
||||
logger.error(f"Sarvam API error: {e}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam API error: {e}"))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Sarvam: {e}")
|
||||
self._socket_client = None
|
||||
self._websocket_context = None
|
||||
await self.push_error(ErrorFrame(f"Failed to connect to Sarvam: {e}"))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Sarvam WebSocket API using SDK."""
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
if self._websocket_context and self._socket_client:
|
||||
try:
|
||||
# Exit the async context manager
|
||||
await self._websocket_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing WebSocket connection: {e}")
|
||||
finally:
|
||||
logger.debug("Disconnected from Sarvam WebSocket")
|
||||
self._socket_client = None
|
||||
self._websocket_context = None
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming messages from Sarvam WebSocket.
|
||||
|
||||
This task wraps the SDK's start_listening() method which processes
|
||||
messages via the registered event handler callback.
|
||||
"""
|
||||
if not self._socket_client:
|
||||
return
|
||||
|
||||
try:
|
||||
# Start listening for messages from the Sarvam SDK
|
||||
# Messages will be handled via the _message_handler callback
|
||||
await self._socket_client.start_listening()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Sarvam receive task: {e}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam receive task error: {e}"))
|
||||
|
||||
async def _handle_message(self, message):
|
||||
"""Handle incoming WebSocket message from Sarvam SDK.
|
||||
|
||||
Processes transcription data and VAD events from the Sarvam service.
|
||||
|
||||
Args:
|
||||
message: The parsed response object from Sarvam WebSocket.
|
||||
"""
|
||||
logger.debug(f"Received response: {message}")
|
||||
|
||||
try:
|
||||
if message.type == "events":
|
||||
# VAD event
|
||||
signal = message.data.signal_type
|
||||
timestamp = message.data.occured_at
|
||||
logger.debug(f"VAD Signal: {signal}, Occurred at: {timestamp}")
|
||||
|
||||
if signal == "START_SPEECH":
|
||||
await self.start_metrics()
|
||||
logger.debug("User started speaking")
|
||||
await self._call_event_handler("on_speech_started")
|
||||
|
||||
elif message.type == "data":
|
||||
await self.stop_ttfb_metrics()
|
||||
transcript = message.data.transcript
|
||||
language_code = message.data.language_code
|
||||
# Prefer language from message (auto-detected for translate models). Fallback to configured.
|
||||
if language_code:
|
||||
language = self._map_language_code_to_enum(language_code)
|
||||
elif self._language_string:
|
||||
language = self._map_language_code_to_enum(self._language_string)
|
||||
else:
|
||||
language = Language.HI_IN
|
||||
|
||||
# Emit utterance end event
|
||||
await self._call_event_handler("on_utterance_end")
|
||||
|
||||
if transcript and transcript.strip():
|
||||
# Record tracing for this transcription event
|
||||
await self._handle_transcription(transcript, True, language)
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=(message.dict() if hasattr(message, "dict") else str(message)),
|
||||
)
|
||||
)
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling Sarvam message: {e}")
|
||||
await self.push_error(ErrorFrame(f"Failed to handle message: {e}"))
|
||||
await self.stop_all_metrics()
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _map_language_code_to_enum(self, language_code: str) -> Language:
|
||||
"""Map Sarvam language code to pipecat Language enum."""
|
||||
mapping = {
|
||||
"bn-IN": Language.BN_IN,
|
||||
"gu-IN": Language.GU_IN,
|
||||
"hi-IN": Language.HI_IN,
|
||||
"kn-IN": Language.KN_IN,
|
||||
"ml-IN": Language.ML_IN,
|
||||
"mr-IN": Language.MR_IN,
|
||||
"ta-IN": Language.TA_IN,
|
||||
"te-IN": Language.TE_IN,
|
||||
"pa-IN": Language.PA_IN,
|
||||
"od-IN": Language.OR_IN,
|
||||
"en-US": Language.EN_US,
|
||||
"en-IN": Language.EN_IN,
|
||||
"as-IN": Language.AS_IN,
|
||||
}
|
||||
return mapping.get(language_code, Language.HI_IN)
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start TTFB and processing metrics collection."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
18
uv.lock
generated
18
uv.lock
generated
@@ -4550,6 +4550,7 @@ runner = [
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
sarvam = [
|
||||
{ name = "sarvamai" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
sentry = [
|
||||
@@ -4704,6 +4705,7 @@ requires-dist = [
|
||||
{ name = "python-dotenv", marker = "extra == 'runner'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "pyvips", extras = ["binary"], marker = "extra == 'moondream'", specifier = "~=3.0.0" },
|
||||
{ name = "resampy", specifier = "~=0.4.3" },
|
||||
{ name = "sarvamai", marker = "extra == 'sarvam'", specifier = "==0.1.21" },
|
||||
{ name = "sentry-sdk", marker = "extra == 'sentry'", specifier = ">=2.28.0,<3" },
|
||||
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=0.1.10" },
|
||||
{ name = "soundfile", marker = "extra == 'soundfile'", specifier = "~=0.13.0" },
|
||||
@@ -6212,6 +6214,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sarvamai"
|
||||
version = "0.1.21"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-core" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e9/08/e5efcb30818ed220b818319255c22fd91e379489ebaa93efd6f444fb4987/sarvamai-0.1.21.tar.gz", hash = "sha256:865065635b2b99d40f5519308832954015627938e06a6333b5f62ae9c36278bb", size = 87386, upload-time = "2025-10-07T07:37:47.085Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2e/4e/b9933f72681b7aed91b86913337dd3981fad97027881fbc66c3c5eb03568/sarvamai-0.1.21-py3-none-any.whl", hash = "sha256:daa4e5d16635fe434f5f270cee416849249285369141d77132a17f0bf670f120", size = 175204, upload-time = "2025-10-07T07:37:46.024Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scipy"
|
||||
version = "1.15.3"
|
||||
|
||||
Reference in New Issue
Block a user