Compare commits

...

1 Commits

Author SHA1 Message Date
Mark Backman
32446c40f2 Add a NIM LLM service 2024-12-03 23:10:48 -05:00
2 changed files with 219 additions and 0 deletions

View File

@@ -16,13 +16,19 @@ from PIL import Image
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
<<<<<<< Updated upstream
AudioRawFrame,
=======
CancelFrame,
EndFrame,
>>>>>>> Stashed changes
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
StartFrame,
TextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
@@ -45,8 +51,12 @@ from pipecat.transcriptions.language import Language
try:
import google.ai.generativelanguage as glm
import google.generativeai as gai
<<<<<<< Updated upstream
from google.cloud import texttospeech_v1
from google.generativeai.types import GenerationConfig
=======
from google.cloud import speech, texttospeech_v1
>>>>>>> Stashed changes
from google.oauth2 import service_account
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -827,3 +837,107 @@ class GoogleTTSService(TTSService):
yield ErrorFrame(error=error_message)
finally:
yield TTSStoppedFrame()
<<<<<<< Updated upstream
=======
class GoogleSTTService(STTService):
    """Streaming speech-to-text service backed by Google Cloud Speech-to-Text.

    Audio passed to ``run_stt`` is queued and forwarded to a single long-lived
    streaming recognition call. A background task consumes the responses and
    pushes ``InterimTranscriptionFrame`` / ``TranscriptionFrame`` frames as
    results arrive, so ``run_stt`` never blocks the event loop waiting for the
    recognizer.

    Args:
        credentials_path: Path to a Google service-account JSON key file.
        language: Recognition language. Defaults to ``Language.EN``.
        sample_rate: Sample rate in Hz of the LINEAR16 audio that will be sent.
        **kwargs: Additional keyword arguments passed to ``STTService``.
    """

    # Upper bound (seconds) a graceful shutdown waits for trailing results
    # after the request stream has been closed.
    _DRAIN_TIMEOUT_SECS = 2.0

    def __init__(
        self,
        *,
        credentials_path: str,
        language: Language = Language.EN,
        sample_rate: int = 16000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._credentials_path = credentials_path
        self._language = language
        self._sample_rate = sample_rate
        self._client = None
        self._streaming_config = None
        self._requests_queue = asyncio.Queue()
        self._responses = None
        self._receive_task = None

    async def start(self, frame: StartFrame):
        """Create the Speech client and open the recognition stream."""
        await super().start(frame)
        credentials = service_account.Credentials.from_service_account_file(self._credentials_path)
        # The async client accepts an async request iterator and returns an
        # async response stream; the sync client would need a plain iterator
        # and would block the event loop while reading responses.
        self._client = speech.SpeechAsyncClient(credentials=credentials)
        self._streaming_config = self._build_streaming_config()
        self._open_stream()

    async def stop(self, frame: EndFrame):
        """Close the stream gracefully, giving pending results a chance to arrive."""
        await super().stop(frame)
        await self._disconnect(drain=True)

    async def cancel(self, frame: CancelFrame):
        """Tear the stream down immediately without waiting for pending results."""
        await super().cancel(frame)
        await self._disconnect(drain=False)

    async def set_language(self, language: Language):
        """Switch the recognition language, restarting the stream if one is open."""
        self._language = language
        if not self._client:
            return
        self._streaming_config = self._build_streaming_config()
        # The old stream is fully closed before the new one is opened, and the
        # new stream gets its own queue, so it can never consume the stop
        # sentinel meant for its predecessor.
        await self._close_stream(drain=True)
        self._open_stream()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Queue one chunk of audio for recognition.

        Transcription frames are pushed by the background receive task, not
        yielded from here.

        Raises:
            RuntimeError: If the service has not been started.
        """
        if not self._client or not self._streaming_config or not self._receive_task:
            raise RuntimeError("GoogleSTTService not started")
        await self._requests_queue.put(speech.StreamingRecognizeRequest(audio_content=audio))
        yield None

    def _build_streaming_config(self):
        """Build the streaming config for the current language and sample rate."""
        config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=self._sample_rate,
            language_code=self._language.value,
            enable_automatic_punctuation=True,
        )
        return speech.StreamingRecognitionConfig(config=config, interim_results=True)

    def _open_stream(self):
        """Start a new recognition stream with its own request queue."""
        self._requests_queue = asyncio.Queue()
        self._receive_task = asyncio.create_task(
            self._receive_responses(
                self._client, self._requests_queue, self._streaming_config, self._language
            )
        )

    async def _close_stream(self, *, drain: bool):
        """Close the current stream; wait for trailing results only if ``drain``."""
        task, self._receive_task = self._receive_task, None
        self._responses = None
        if task is None:
            return
        # Ends the request generator, which half-closes the RPC.
        await self._requests_queue.put(None)
        if not drain:
            task.cancel()
        _, pending = await asyncio.wait({task}, timeout=self._DRAIN_TIMEOUT_SECS)
        if pending:
            # The recognizer did not finish in time; stop waiting for it.
            task.cancel()
            await asyncio.wait(pending)

    async def _disconnect(self, *, drain: bool):
        """Close the stream and drop the client and its config."""
        await self._close_stream(drain=drain)
        self._client = None
        self._streaming_config = None

    async def _request_generator(self, requests_queue=None, streaming_config=None):
        """Yield the config request, then queued audio requests until ``None``.

        Both arguments default to the service's current queue and config.
        """
        if requests_queue is None:
            requests_queue = self._requests_queue
        if streaming_config is None:
            streaming_config = self._streaming_config
        # The first request of a streaming call carries only the configuration.
        yield speech.StreamingRecognizeRequest(streaming_config=streaming_config)
        while True:
            request = await requests_queue.get()
            if request is None:
                break
            yield request

    async def _receive_responses(self, client, requests_queue, streaming_config, language):
        """Open the streaming call and push a frame for every recognition result."""
        try:
            responses = await client.streaming_recognize(
                requests=self._request_generator(requests_queue, streaming_config)
            )
            self._responses = responses
            async for response in responses:
                for result in response.results:
                    if not result.alternatives:
                        continue
                    transcript = result.alternatives[0].transcript
                    frame_cls = (
                        TranscriptionFrame if result.is_final else InterimTranscriptionFrame
                    )
                    await self.push_frame(
                        frame_cls(transcript, "", time_now_iso8601(), language)
                    )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Report the failure instead of letting the task die silently.
            logger.error(f"{self} streaming recognition error: {e}")
            await self.push_frame(ErrorFrame(error=f"Google STT streaming error: {e}"))
>>>>>>> Stashed changes

105
src/pipecat/services/nim.py Normal file
View File

@@ -0,0 +1,105 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService
class NimLLMService(OpenAILLMService):
    """LLM service for NVIDIA's NIM (NVIDIA Inference Microservice) API.

    Built on top of OpenAILLMService so the OpenAI-style interface keeps
    working against NIM. The one thing handled differently is token usage:
    NIM reports it incrementally, whereas OpenAI sends a single final summary,
    so this class collects the incremental updates and emits one report per
    processed context.

    Args:
        api_key (str): The API key for accessing NVIDIA's NIM API
        base_url (str, optional): The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1"
        model (str, optional): The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct"
        **kwargs: Additional keyword arguments passed to OpenAILLMService

    Example:
        ```python
        service = NimLLMService(
            api_key="your-api-key",
            model="nvidia/llama-3.1-nemotron-70b-instruct"
        )
        ```
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "https://integrate.api.nvidia.com/v1",
        model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
        **kwargs,
    ):
        super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
        # Token usage accumulators, plus the flag that gates accumulation
        self._reset_usage_counters()
        self._is_processing = False

    def _reset_usage_counters(self):
        """Zero the accumulated token counts and the prompt-seen flag."""
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._total_tokens = 0
        self._has_reported_prompt_tokens = False

    async def _process_context(self, context: OpenAILLMContext):
        """Run a context through the LLM and report token usage once at the end.

        Overrides the parent implementation: usage updates that arrive while
        the parent is processing are accumulated on this instance, and a single
        combined report is forwarded to the metrics system afterwards.

        Args:
            context (OpenAILLMContext): The context to process, containing messages
                and other information needed for the LLM interaction.
        """
        # Start every run from a clean slate
        self._reset_usage_counters()
        self._is_processing = True
        try:
            await super()._process_context(context)
        finally:
            self._is_processing = False

        # Nothing was reported during this run, so there is nothing to emit
        if self._prompt_tokens <= 0 and self._completion_tokens <= 0:
            return

        self._total_tokens = self._prompt_tokens + self._completion_tokens
        usage = LLMTokenUsage(
            prompt_tokens=self._prompt_tokens,
            completion_tokens=self._completion_tokens,
            total_tokens=self._total_tokens,
        )
        await super().start_llm_usage_metrics(usage)

    async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
        """Accumulate an incremental token usage update.

        Instead of forwarding each update to the metrics system, the values are
        folded into the running totals; the combined figures are reported when
        context processing finishes. Updates received while no context is being
        processed are ignored.

        Args:
            tokens (LLMTokenUsage): The token usage metrics for the current chunk
                of processing, containing prompt_tokens and completion_tokens counts.
        """
        if self._is_processing:
            # Prompt tokens are taken from the first update that carries them
            if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
                self._prompt_tokens = tokens.prompt_tokens
                self._has_reported_prompt_tokens = True

            # Completion tokens only ever move upwards
            self._completion_tokens = max(self._completion_tokens, tokens.completion_tokens)