Compare commits
1 Commits
hush/reset
...
mb/nim-llm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32446c40f2 |
@@ -16,13 +16,19 @@ from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
<<<<<<< Updated upstream
|
||||
AudioRawFrame,
|
||||
=======
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
>>>>>>> Stashed changes
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -45,8 +51,12 @@ from pipecat.transcriptions.language import Language
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as gai
|
||||
<<<<<<< Updated upstream
|
||||
from google.cloud import texttospeech_v1
|
||||
from google.generativeai.types import GenerationConfig
|
||||
=======
|
||||
from google.cloud import speech, texttospeech_v1
|
||||
>>>>>>> Stashed changes
|
||||
from google.oauth2 import service_account
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -827,3 +837,107 @@ class GoogleTTSService(TTSService):
|
||||
yield ErrorFrame(error=error_message)
|
||||
finally:
|
||||
yield TTSStoppedFrame()
|
||||
<<<<<<< Updated upstream
|
||||
=======
|
||||
|
||||
|
||||
class GoogleSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials_path: str,
|
||||
language: Language = Language.EN,
|
||||
sample_rate: int = 16000,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._credentials_path = credentials_path
|
||||
self._language = language
|
||||
self._sample_rate = sample_rate
|
||||
self._client = None
|
||||
self._streaming_config = None
|
||||
self._requests_queue = asyncio.Queue()
|
||||
self._responses = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
credentials = service_account.Credentials.from_service_account_file(self._credentials_path)
|
||||
self._client = speech.SpeechClient(credentials=credentials)
|
||||
|
||||
config = speech.RecognitionConfig(
|
||||
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self._sample_rate,
|
||||
language_code=self._language.value,
|
||||
enable_automatic_punctuation=True,
|
||||
)
|
||||
self._streaming_config = speech.StreamingRecognitionConfig(
|
||||
config=config, interim_results=True
|
||||
)
|
||||
|
||||
# Start the recognition stream
|
||||
self._responses = self._client.streaming_recognize(
|
||||
self._streaming_config, self._request_generator()
|
||||
)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._requests_queue.put(None) # Signal to stop the request generator
|
||||
self._client = None
|
||||
self._streaming_config = None
|
||||
self._responses = None
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self.stop(EndFrame())
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
self._language = language
|
||||
# Recreate the streaming config with the new language
|
||||
if self._client:
|
||||
config = speech.RecognitionConfig(
|
||||
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self._sample_rate,
|
||||
language_code=self._language.value,
|
||||
enable_automatic_punctuation=True,
|
||||
)
|
||||
self._streaming_config = speech.StreamingRecognitionConfig(
|
||||
config=config, interim_results=True
|
||||
)
|
||||
# Restart the recognition stream
|
||||
await self._requests_queue.put(None) # Signal to stop the current request generator
|
||||
self._responses = self._client.streaming_recognize(
|
||||
self._streaming_config, self._request_generator()
|
||||
)
|
||||
|
||||
async def _request_generator(self):
|
||||
while True:
|
||||
request = await self._requests_queue.get()
|
||||
if request is None:
|
||||
break
|
||||
yield request
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
if not self._client or not self._streaming_config or not self._responses:
|
||||
raise RuntimeError("GoogleSTTService not started")
|
||||
|
||||
# Queue the audio content
|
||||
await self._requests_queue.put(speech.StreamingRecognizeRequest(audio_content=audio))
|
||||
|
||||
# Process the responses
|
||||
for response in self._responses:
|
||||
for result in response.results:
|
||||
if result.alternatives:
|
||||
transcript = result.alternatives[0].transcript
|
||||
if result.is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), self._language)
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript, "", time_now_iso8601(), self._language
|
||||
)
|
||||
)
|
||||
|
||||
yield None
|
||||
>>>>>>> Stashed changes
|
||||
|
||||
105
src/pipecat/services/nim.py
Normal file
105
src/pipecat/services/nim.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
|
||||
class NimLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for accessing NVIDIA's NIM API
|
||||
base_url (str, optional): The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1"
|
||||
model (str, optional): The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct"
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService
|
||||
|
||||
Example:
|
||||
```python
|
||||
service = NimLLMService(
|
||||
api_key="your-api-key",
|
||||
model="nvidia/llama-3.1-nemotron-70b-instruct"
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The context to process, containing messages
|
||||
and other information needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens (LLMTokenUsage): The token usage metrics for the current chunk
|
||||
of processing, containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
Reference in New Issue
Block a user