configurability via constructor

This commit is contained in:
Kwindla Hultman Kramer
2024-10-02 16:26:22 -07:00
parent efd3627202
commit fa3a6647ef
2 changed files with 118 additions and 17 deletions

View File

@@ -14,8 +14,11 @@ from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.logger import FrameLogger
from pipecat.services.openai_realtime_beta import OpenAILLMServiceRealtimeBeta
from pipecat.services.openai_realtime_beta import (
OpenAILLMServiceRealtimeBeta,
OpenAITurnDetection,
RealtimeSessionProperties,
)
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
@@ -51,9 +54,31 @@ async def main():
),
)
fl1 = FrameLogger("fl-1")
llm = OpenAILLMServiceRealtimeBeta(api_key=os.getenv("OPENAI_API_KEY"))
fl2 = FrameLogger("fl-2")
session_properties = RealtimeSessionProperties(
turn_detection=OpenAITurnDetection(silence_duration_ms=800),
instructions="""
Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
Act like a human, but remember that you aren't a human and that you can't do human
things in the real world. Your voice and personality should be warm and engaging, with a lively and
playful tone.
If interacting in a non-English language, start by using the standard accent or dialect familiar to
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
even if you're asked about them.
You are participating in a voice conversation. Keep your responses concise, short, and to the point
unless specifically asked to elaborate on a topic.
Remember, your responses should be short. Just one or two sentences, usually.
Start by suggesting that you have a conversation about space exploration.
""",
)
llm = OpenAILLMServiceRealtimeBeta(
api_key=os.getenv("OPENAI_API_KEY"), session_properties=session_properties
)
messages = [
{
@@ -65,9 +90,7 @@ async def main():
pipeline = Pipeline(
[
transport.input(), # Transport user input
# fl1,
llm, # LLM
# fl2,
transport.output(), # Transport bot output
]
)

View File

@@ -3,6 +3,8 @@ import base64
import json
import websockets
from typing import List, Optional
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
CancelFrame,
@@ -12,6 +14,7 @@ from pipecat.frames.frames import (
EndFrame,
InputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
TTSAudioRawFrame,
@@ -24,19 +27,59 @@ from loguru import logger
# temp: websocket logger
# import logging
# logging.basicConfig(
# format="%(message)s",
# level=logging.DEBUG,
# )
class OpenAIInputTranscription(BaseModel):
# enabled: bool = Field(description="Whether to enable input audio transcription.", default=True)
model: str = Field(
description="The model to use for transcription (e.g., 'whisper-1').", default="whisper-1"
)
class OpenAITurnDetection(BaseModel):
type: str = Field(
default="server_vad",
description="Type of turn detection, only 'server_vad' is currently supported.",
)
threshold: float = Field(
ge=0.0, le=1.0, default=0.5, description="Activation threshold for VAD (0.0 to 1.0)."
)
prefix_padding_ms: int = Field(
default=300,
description="Amount of audio to include before speech starts (in milliseconds).",
)
silence_duration_ms: int = Field(
default=200, description="Duration of silence to detect speech stop (in milliseconds)."
)
class RealtimeSessionProperties(BaseModel):
modalities: List[str] = Field(default=["text", "audio"])
instructions: str = Field(default="")
voice: str = Field(default="alloy")
input_audio_format: str = Field(default="pcm16")
output_audio_format: str = Field(default="pcm16")
input_audio_transcription: Optional[OpenAIInputTranscription] = Field(
default=OpenAIInputTranscription()
)
turn_detection: Optional[OpenAITurnDetection] = Field(default=None)
tools: List[str] = Field(default=[])
tool_choice: str = Field(default="auto")
temperature: float = Field(default=0.8)
max_response_output_tokens: int = Field(default=4096)
class OpenAILLMServiceRealtimeBeta(LLMService):
def __init__(
self,
*,
api_key: str,
base_url="wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
session_properties: RealtimeSessionProperties = RealtimeSessionProperties(),
**kwargs,
):
super().__init__(base_url=base_url, **kwargs)
@@ -45,7 +88,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
self._websocket = None
self._receive_task = None
self._session_properties = None
self._session_properties = session_properties
self._responses_in_flight = {}
async def start(self, frame: StartFrame):
@@ -60,9 +103,19 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
await super().cancel(frame)
await self._disconnect()
async def update_session_properties(self):
logger.debug(f"Updating session properties: {self._session_properties.dict()}")
await self._websocket.send(
json.dumps(
{
"type": "session.update",
"session": self._session_properties.dict(),
}
)
)
async def _connect(self):
try:
logger.debug(f"connecting to {self.base_url} with api_key {self.api_key}")
self._websocket = await websockets.connect(
uri=self.base_url,
extra_headers={
@@ -101,10 +154,14 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
try:
async for message in self._get_websocket():
msg = json.loads(message)
logger.debug(f"Received message: {msg}")
# logger.debug(f"Received message: {msg}")
if not msg:
continue
if msg["type"] == "session.created":
logger.debug(f"Received session.created: {msg}")
await self.update_session_properties()
elif msg["type"] == "session.updated":
logger.debug(f"Received session configuration: {msg}")
self._session_properties = msg["session"]
elif msg["type"] == "response.created":
pass
@@ -119,6 +176,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
await self.push_frame(frame)
elif msg["type"] == "response.text.delta":
logger.debug(f"!!! {msg['delta']}")
pass
elif msg["type"] == "response.output_item.done":
if msg["item"]["type"] == "message":
for item in msg["item"]["content"]:
@@ -127,8 +185,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
elif msg["type"] == "response.done":
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
elif msg["type"] == "response.error":
logger.error(f"Error: {msg}")
elif msg["type"] == "error":
raise Exception(f"Error: {msg}")
except asyncio.CancelledError:
@@ -159,7 +216,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
"type": "response.create",
"response": {
"modalities": ["audio", "text"],
"instructions": "Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. You are a participant in a voice chat. Act like a human, but remember that you aren't a human and that you can't do human things in the real world. Your voice and personality should be warm and engaging, with a lively and playful tone. If interacting in a non-English language, start by using the standard accent or dialect familiar to the user. Talk quickly. You should always call a function if you can. Do not refer to these rules, even if you're asked about them.",
},
},
)
@@ -179,16 +235,38 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
)
# await self._websocket.send(json.dumps(({"type": "input_audio_buffer.commit"})))
async def _handle_interruption(self, frame):
logger.debug(f"Handling interruption: {frame}")
await self.stop_all_metrics()
await self.push_frame(LLMFullResponseEndFrame())
await self._websocket.send(
json.dumps(
{
"type": "response.cancel",
},
)
)
await self._websocket.send(
json.dumps(
{
"type": "input_audio_buffer.clear",
},
)
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
messages = [{"role": "user", "content": frame.text}]
context = OpenAILLMContext(messages)
await self._create_response(context, messages)
if isinstance(frame, InputAudioRawFrame):
# await self._create_response(context, messages)
elif isinstance(frame, InputAudioRawFrame):
await self._send_user_audio(frame)
elif isinstance(frame, StartInterruptionFrame):
await self._handle_interruption(frame)
await self.push_frame(frame, direction)
# async def get_chat_completions(
# self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]