From 32446c40f2ae9cbc6bbd64f666b7f8171c8bc265 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 3 Dec 2024 23:10:48 -0500 Subject: [PATCH] Add a NIM LLM service --- src/pipecat/services/google.py | 114 +++++++++++++++++++++++++++++++++ src/pipecat/services/nim.py | 105 ++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 src/pipecat/services/nim.py diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index be8d80d1f..8f96d016b 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -16,13 +16,19 @@ from PIL import Image from pydantic import BaseModel, Field from pipecat.frames.frames import ( +<<<<<<< Updated upstream AudioRawFrame, +======= + CancelFrame, + EndFrame, +>>>>>>> Stashed changes ErrorFrame, Frame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, LLMUpdateSettingsFrame, + StartFrame, TextFrame, TTSAudioRawFrame, TTSStartedFrame, @@ -45,8 +51,12 @@ from pipecat.transcriptions.language import Language try: import google.ai.generativelanguage as glm import google.generativeai as gai +<<<<<<< Updated upstream from google.cloud import texttospeech_v1 from google.generativeai.types import GenerationConfig +======= + from google.cloud import speech, texttospeech_v1 +>>>>>>> Stashed changes from google.oauth2 import service_account except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -827,3 +837,107 @@ class GoogleTTSService(TTSService): yield ErrorFrame(error=error_message) finally: yield TTSStoppedFrame() +<<<<<<< Updated upstream +======= + + +class GoogleSTTService(STTService): + def __init__( + self, + *, + credentials_path: str, + language: Language = Language.EN, + sample_rate: int = 16000, + **kwargs, + ): + super().__init__(**kwargs) + self._credentials_path = credentials_path + self._language = language + self._sample_rate = sample_rate + self._client = None + self._streaming_config = None + self._requests_queue = asyncio.Queue() + self._responses = None + + async def start(self, frame: StartFrame): + await super().start(frame) + credentials = service_account.Credentials.from_service_account_file(self._credentials_path) + self._client = speech.SpeechClient(credentials=credentials) + + config = speech.RecognitionConfig( + encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=self._sample_rate, + language_code=self._language.value, + enable_automatic_punctuation=True, + ) + self._streaming_config = speech.StreamingRecognitionConfig( + config=config, interim_results=True + ) + + # Start the recognition stream + self._responses = self._client.streaming_recognize( + self._streaming_config, self._request_generator() + ) + + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._requests_queue.put(None) # Signal to stop the request generator + self._client = None + self._streaming_config = None + self._responses = None + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self.stop(EndFrame()) + + async def set_language(self, language: Language): + self._language = language + # Recreate the streaming config with the new language + if self._client: + config = speech.RecognitionConfig( + encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=self._sample_rate, + language_code=self._language.value, + enable_automatic_punctuation=True, + ) + self._streaming_config = speech.StreamingRecognitionConfig( + config=config, interim_results=True + ) + # Restart the recognition stream + await self._requests_queue.put(None) # Signal to stop the current request generator + self._responses = self._client.streaming_recognize( + self._streaming_config, self._request_generator() + ) + + async def _request_generator(self): + while True: + request = await self._requests_queue.get() + if request is None: + break + yield request + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + if not self._client or not self._streaming_config or not self._responses: + raise RuntimeError("GoogleSTTService not started") + + # Queue the audio content + await self._requests_queue.put(speech.StreamingRecognizeRequest(audio_content=audio)) + + # Process the responses + for response in self._responses: + for result in response.results: + if result.alternatives: + transcript = result.alternatives[0].transcript + if result.is_final: + await self.push_frame( + TranscriptionFrame(transcript, "", time_now_iso8601(), self._language) + ) + else: + await self.push_frame( + InterimTranscriptionFrame( + transcript, "", time_now_iso8601(), self._language + ) + ) + + yield None +>>>>>>> Stashed changes diff --git a/src/pipecat/services/nim.py b/src/pipecat/services/nim.py new file mode 100644 index 000000000..0ce0171c9 --- /dev/null +++ b/src/pipecat/services/nim.py @@ -0,0 +1,105 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + +from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai import OpenAILLMService + + +class NimLLMService(OpenAILLMService): + """A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API. + + This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining + compatibility with the OpenAI-style interface. It specifically handles the difference + in token usage reporting between NIM (incremental) and OpenAI (final summary). + + Args: + api_key (str): The API key for accessing NVIDIA's NIM API + base_url (str, optional): The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1" + model (str, optional): The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct" + **kwargs: Additional keyword arguments passed to OpenAILLMService + + Example: + ```python + service = NimLLMService( + api_key="your-api-key", + model="nvidia/llama-3.1-nemotron-70b-instruct" + ) + ``` + """ + + def __init__( + self, + *, + api_key: str, + base_url: str = "https://integrate.api.nvidia.com/v1", + model: str = "nvidia/llama-3.1-nemotron-70b-instruct", + **kwargs, + ): + super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) + # Counters for accumulating token usage metrics + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._total_tokens = 0 + self._has_reported_prompt_tokens = False + self._is_processing = False + + async def _process_context(self, context: OpenAILLMContext): + """Process a context through the LLM and accumulate token usage metrics. + + This method overrides the parent class implementation to handle NVIDIA's + incremental token reporting style, accumulating the counts and reporting + them once at the end of processing. + + Args: + context (OpenAILLMContext): The context to process, containing messages + and other information needed for the LLM interaction. + """ + # Reset all counters and flags at the start of processing + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._total_tokens = 0 + self._has_reported_prompt_tokens = False + self._is_processing = True + + try: + await super()._process_context(context) + finally: + self._is_processing = False + # Report final accumulated token usage at the end of processing + if self._prompt_tokens > 0 or self._completion_tokens > 0: + self._total_tokens = self._prompt_tokens + self._completion_tokens + tokens = LLMTokenUsage( + prompt_tokens=self._prompt_tokens, + completion_tokens=self._completion_tokens, + total_tokens=self._total_tokens, + ) + await super().start_llm_usage_metrics(tokens) + + async def start_llm_usage_metrics(self, tokens: LLMTokenUsage): + """Accumulate token usage metrics during processing. + + This method intercepts the incremental token updates from NVIDIA's API + and accumulates them instead of passing each update to the metrics system. + The final accumulated totals are reported at the end of processing. + + Args: + tokens (LLMTokenUsage): The token usage metrics for the current chunk + of processing, containing prompt_tokens and completion_tokens counts. + """ + # Only accumulate metrics during active processing + if not self._is_processing: + return + + # Record prompt tokens the first time we see them + if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0: + self._prompt_tokens = tokens.prompt_tokens + self._has_reported_prompt_tokens = True + + # Update completion tokens count if it has increased + if tokens.completion_tokens > self._completion_tokens: + self._completion_tokens = tokens.completion_tokens