Compare commits

...

24 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
8d6b8b035e Merge pull request #332 from pipecat-ai/aleix/allow-internal-http-sessions
services: allow internal http sessions if none is given
2024-07-31 15:51:52 -07:00
Aleix Conchillo Flaqué
0a15874c12 services: allow internal http sessions if none is given 2024-07-30 17:44:18 -07:00
Aleix Conchillo Flaqué
d60e99a043 examples(06a-image-sync): make sure frames go downstream 2024-07-30 11:41:58 -07:00
Aleix Conchillo Flaqué
77723b34c7 EndFrame tries to end gracefully CancelFrame cancels tasks 2024-07-30 11:41:19 -07:00
Aleix Conchillo Flaqué
c466d34a06 Merge pull request #328 from pipecat-ai/aleix/rtvi-towards-custom-pipelines
processors(rtvi): refactor to allow future custom pipelines
2024-07-29 15:07:57 -07:00
Aleix Conchillo Flaqué
f816897833 Merge pull request #327 from pipecat-ai/aleix/bot-start-stop-speaking-frames
bot start stop speaking frames
2024-07-27 17:21:23 -07:00
Aleix Conchillo Flaqué
c1e8a5e522 processors(rtvi): refactor to allow future custom pipelines 2024-07-26 10:26:36 -07:00
Aleix Conchillo Flaqué
76aca32f2e transport(output): emit new bot start|stop speaking frames 2024-07-25 14:50:33 -07:00
Aleix Conchillo Flaqué
7e31b2a795 processors(user_idle): use user speaking instead of interruption frames 2024-07-25 14:47:56 -07:00
Aleix Conchillo Flaqué
028e38a86b Merge pull request #326 from pipecat-ai/aleix/rtvi-bot-ready-fixes
rtvi: send bot-ready when pipeline is ready and first participant joins
2024-07-25 11:39:14 -07:00
Aleix Conchillo Flaqué
8cf7649855 processors(rtvi): send bot-ready when pipeline AND first participant joins 2024-07-25 11:25:51 -07:00
Aleix Conchillo Flaqué
64f5119b08 transports(base): allow registering event handlers without decorators 2024-07-25 11:24:24 -07:00
Aleix Conchillo Flaqué
4d606aefb3 update CHANGELOG 2024-07-25 09:57:01 -07:00
Ankur Duggal
4bafdaa04d Deepgram Adjustments (#313) 2024-07-25 09:51:51 -07:00
Aleix Conchillo Flaqué
5afe1abf82 Merge pull request #323 from pipecat-ai/aleix/base-input-handle-incoming-interruptions
transports(inputs): handle start/stop interruption frames
2024-07-24 15:16:18 -07:00
Aleix Conchillo Flaqué
f066d50b98 transports(inputs): handle start/stop interruption frames 2024-07-24 15:15:09 -07:00
Aleix Conchillo Flaqué
91103e21cc github(publish_test): download tags and depth to 100 2024-07-24 14:49:09 -07:00
Aleix Conchillo Flaqué
f44dabcd65 Merge pull request #322 from pipecat-ai/aleix/base-input-transport-system-frames-fix
transports(inputs): don't queue incoming system frames
2024-07-24 14:44:18 -07:00
Aleix Conchillo Flaqué
0fd2fca231 frames: StartFrame is now a control frame 2024-07-24 14:42:59 -07:00
Aleix Conchillo Flaqué
5bb64098e7 transports(inputs): don't queue incoming system frames 2024-07-24 14:35:00 -07:00
Aleix Conchillo Flaqué
3fc85e75e0 Merge pull request #320 from pipecat-ai/aleix/req-updates-072324
update project requirements and dependencies
2024-07-23 17:45:18 -07:00
Aleix Conchillo Flaqué
3f61ea16b7 update project requirements and dependencies 2024-07-23 17:35:47 -07:00
Aleix Conchillo Flaqué
4b393092b5 Merge pull request #319 from pipecat-ai/aleix/daily-completion-callbacks-0.0.39-fix
transports(daily): fix completion callbacks handling
2024-07-23 15:27:26 -07:00
Aleix Conchillo Flaqué
b583f5162b transports(daily): fix completion callbacks handling 2024-07-23 15:25:59 -07:00
25 changed files with 603 additions and 368 deletions

View File

@@ -9,6 +9,9 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
fetch-tags: true
fetch-depth: 100
- name: Set up Python
id: setup_python
uses: actions/setup-python@v4

View File

@@ -5,6 +5,43 @@ All notable changes to **pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Added new `BotStartedSpeakingFrame` and `BotStoppedSpeakingFrame` control
frames. These frames are pushed upstream and they should wrap
`BotSpeakingFrame`.
- Transports now allow you to register event handlers without decorators.
### Changed
- `BotSpeakingFrame` is now a control frame.
- `StartFrame` is now a control frame similar to `EndFrame`.
- `DeepgramTTSService` now is more customizable. You can adjust the encoding and
sample rate.
### Fixed
- RTVI's `bot-ready` message is now sent when the RTVI pipeline is ready and
a first participant joins.
- Fixed a `BaseInputTransport` issue that was causing incoming system frames to
be queued instead of being pushed immediately.
- Fixed a `BaseInputTransport` issue that was causing start/stop interruptions
incoming frames to not cancel tasks and be processed properly.
## [0.0.39] - 2024-07-23
### Fixed
- Fixed a regression introduced in 0.0.38 that would cause Daily transcription
to stop the Pipeline.
## [0.0.38] - 2024-07-23
### Added

View File

@@ -4,5 +4,5 @@ grpcio-tools~=1.62.2
pip-tools~=7.4.1
pyright~=1.1.367
pytest~=8.2.0
setuptools~=69.5.1
setuptools~=71.1.0
setuptools_scm~=8.1.0

View File

@@ -51,7 +51,7 @@ class ImageSyncAggregator(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if not isinstance(frame, SystemFrame):
if not isinstance(frame, SystemFrame) and direction == FrameDirection.DOWNSTREAM:
await self.push_frame(ImageRawFrame(image=self._speaking_image_bytes, size=(1024, 1024), format=self._speaking_image_format))
await self.push_frame(frame)
await self.push_frame(ImageRawFrame(image=self._waiting_image_bytes, size=(1024, 1024), format=self._waiting_image_format))

View File

@@ -8,7 +8,6 @@ aiofiles==24.1.0
# via deepgram-sdk
aiohttp==3.9.5
# via
# cartesia
# deepgram-sdk
# langchain
# langchain-community
@@ -36,17 +35,15 @@ attrs==23.2.0
# via
# aiohttp
# openpipe
av==12.2.0
av==12.3.0
# via faster-whisper
azure-cognitiveservices-speech==1.38.0
# via pipecat-ai (pyproject.toml)
blinker==1.8.2
# via flask
cachetools==5.3.3
cachetools==5.4.0
# via google-auth
cartesia==1.0.3
# via pipecat-ai (pyproject.toml)
certifi==2024.6.2
certifi==2024.7.4
# via
# httpcore
# httpx
@@ -80,13 +77,11 @@ einops==0.8.0
# via pipecat-ai (pyproject.toml)
email-validator==2.2.0
# via fastapi
exceptiongroup==1.2.1
# via
# anyio
# pytest
exceptiongroup==1.2.2
# via anyio
fal-client==0.4.1
# via pipecat-ai (pyproject.toml)
fastapi==0.111.0
fastapi==0.111.1
# via pipecat-ai (pyproject.toml)
fastapi-cli==0.0.4
# via fastapi
@@ -124,9 +119,9 @@ google-api-core[grpc]==2.19.1
# google-ai-generativelanguage
# google-api-python-client
# google-generativeai
google-api-python-client==2.135.0
google-api-python-client==2.137.0
# via google-generativeai
google-auth==2.31.0
google-auth==2.32.0
# via
# google-ai-generativelanguage
# google-api-core
@@ -135,7 +130,7 @@ google-auth==2.31.0
# google-generativeai
google-auth-httplib2==0.2.0
# via google-api-python-client
google-generativeai==0.7.1
google-generativeai==0.7.2
# via pipecat-ai (pyproject.toml)
googleapis-common-protos==1.63.2
# via
@@ -143,7 +138,7 @@ googleapis-common-protos==1.63.2
# grpcio-status
greenlet==3.0.3
# via sqlalchemy
grpcio==1.64.1
grpcio==1.65.1
# via
# google-api-core
# grpcio-status
@@ -165,7 +160,6 @@ httptools==0.6.1
httpx==0.27.0
# via
# anthropic
# cartesia
# deepgram-sdk
# fal-client
# fastapi
@@ -173,7 +167,7 @@ httpx==0.27.0
# openpipe
httpx-sse==0.4.0
# via fal-client
huggingface-hub==0.23.4
huggingface-hub==0.24.1
# via
# faster-whisper
# timm
@@ -188,8 +182,6 @@ idna==3.7
# httpx
# requests
# yarl
iniconfig==2.0.0
# via pytest
itsdangerous==2.2.0
# via flask
jinja2==3.1.4
@@ -203,23 +195,23 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.6
langchain==0.2.11
# via
# langchain-community
# pipecat-ai (pyproject.toml)
langchain-community==0.2.6
langchain-community==0.2.10
# via pipecat-ai (pyproject.toml)
langchain-core==0.2.10
langchain-core==0.2.23
# via
# langchain
# langchain-community
# langchain-openai
# langchain-text-splitters
langchain-openai==0.1.10
langchain-openai==0.1.17
# via pipecat-ai (pyproject.toml)
langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.83
langsmith==0.1.93
# via
# langchain
# langchain-community
@@ -296,31 +288,26 @@ nvidia-nvtx-cu12==12.1.105
# via torch
onnxruntime==1.18.1
# via faster-whisper
openai==1.27.0
openai==1.35.15
# via
# langchain-openai
# openpipe
# pipecat-ai (pyproject.toml)
openpipe==4.16.0
openpipe==4.18.0
# via pipecat-ai (pyproject.toml)
orjson==3.10.5
# via
# fastapi
# langsmith
orjson==3.10.6
# via langsmith
packaging==24.1
# via
# huggingface-hub
# langchain-core
# marshmallow
# onnxruntime
# pytest
# transformers
pillow==10.3.0
# via
# pipecat-ai (pyproject.toml)
# torchvision
pluggy==1.5.0
# via pytest
proto-plus==1.24.0
# via
# google-ai-generativelanguage
@@ -344,7 +331,7 @@ pyasn1-modules==0.4.0
# via google-auth
pyaudio==0.2.14
# via pipecat-ai (pyproject.toml)
pydantic==2.8.0
pydantic==2.8.2
# via
# anthropic
# fastapi
@@ -353,7 +340,7 @@ pydantic==2.8.0
# langchain-core
# langsmith
# openai
pydantic-core==2.20.0
pydantic-core==2.20.1
# via pydantic
pygments==2.18.0
# via rich
@@ -363,10 +350,6 @@ pyloudnorm==0.1.1
# via pipecat-ai (pyproject.toml)
pyparsing==3.1.2
# via httplib2
pytest==8.2.2
# via pytest-asyncio
pytest-asyncio==0.23.7
# via cartesia
python-dateutil==2.9.0.post0
# via openpipe
python-dotenv==1.0.1
@@ -391,7 +374,6 @@ regex==2024.5.15
# transformers
requests==2.32.3
# via
# cartesia
# google-api-core
# huggingface-hub
# langchain
@@ -428,11 +410,11 @@ sqlalchemy==2.0.31
# langchain-community
starlette==0.37.2
# via fastapi
sympy==1.12.1
sympy==1.13.1
# via
# onnxruntime
# torch
tenacity==8.4.2
tenacity==8.5.0
# via
# langchain
# langchain-community
@@ -446,8 +428,6 @@ tokenizers==0.19.1
# anthropic
# faster-whisper
# transformers
tomli==2.0.1
# via pytest
torch==2.3.1
# via
# pipecat-ai (pyproject.toml)
@@ -489,13 +469,11 @@ typing-extensions==4.12.2
# uvicorn
typing-inspect==0.9.0
# via dataclasses-json
ujson==5.10.0
# via fastapi
uritemplate==4.1.1
# via google-api-python-client
urllib3==2.2.2
# via requests
uvicorn[standard]==0.30.1
uvicorn[standard]==0.30.3
# via fastapi
uvloop==0.19.0
# via uvicorn
@@ -505,7 +483,6 @@ watchfiles==0.22.0
# via uvicorn
websockets==12.0
# via
# cartesia
# deepgram-sdk
# pipecat-ai (pyproject.toml)
# uvicorn

View File

@@ -8,7 +8,6 @@ aiofiles==24.1.0
# via deepgram-sdk
aiohttp==3.9.5
# via
# cartesia
# deepgram-sdk
# langchain
# langchain-community
@@ -36,17 +35,15 @@ attrs==23.2.0
# via
# aiohttp
# openpipe
av==12.2.0
av==12.3.0
# via faster-whisper
azure-cognitiveservices-speech==1.38.0
# via pipecat-ai (pyproject.toml)
blinker==1.8.2
# via flask
cachetools==5.3.3
cachetools==5.4.0
# via google-auth
cartesia==1.0.3
# via pipecat-ai (pyproject.toml)
certifi==2024.6.2
certifi==2024.7.4
# via
# httpcore
# httpx
@@ -80,13 +77,11 @@ einops==0.8.0
# via pipecat-ai (pyproject.toml)
email-validator==2.2.0
# via fastapi
exceptiongroup==1.2.1
# via
# anyio
# pytest
exceptiongroup==1.2.2
# via anyio
fal-client==0.4.1
# via pipecat-ai (pyproject.toml)
fastapi==0.111.0
fastapi==0.111.1
# via pipecat-ai (pyproject.toml)
fastapi-cli==0.0.4
# via fastapi
@@ -123,9 +118,9 @@ google-api-core[grpc]==2.19.1
# google-ai-generativelanguage
# google-api-python-client
# google-generativeai
google-api-python-client==2.135.0
google-api-python-client==2.137.0
# via google-generativeai
google-auth==2.31.0
google-auth==2.32.0
# via
# google-ai-generativelanguage
# google-api-core
@@ -134,13 +129,13 @@ google-auth==2.31.0
# google-generativeai
google-auth-httplib2==0.2.0
# via google-api-python-client
google-generativeai==0.7.1
google-generativeai==0.7.2
# via pipecat-ai (pyproject.toml)
googleapis-common-protos==1.63.2
# via
# google-api-core
# grpcio-status
grpcio==1.64.1
grpcio==1.65.1
# via
# google-api-core
# grpcio-status
@@ -162,7 +157,6 @@ httptools==0.6.1
httpx==0.27.0
# via
# anthropic
# cartesia
# deepgram-sdk
# fal-client
# fastapi
@@ -170,7 +164,7 @@ httpx==0.27.0
# openpipe
httpx-sse==0.4.0
# via fal-client
huggingface-hub==0.23.4
huggingface-hub==0.24.1
# via
# faster-whisper
# timm
@@ -185,8 +179,6 @@ idna==3.7
# httpx
# requests
# yarl
iniconfig==2.0.0
# via pytest
itsdangerous==2.2.0
# via flask
jinja2==3.1.4
@@ -200,23 +192,23 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.6
langchain==0.2.11
# via
# langchain-community
# pipecat-ai (pyproject.toml)
langchain-community==0.2.6
langchain-community==0.2.10
# via pipecat-ai (pyproject.toml)
langchain-core==0.2.10
langchain-core==0.2.23
# via
# langchain
# langchain-community
# langchain-openai
# langchain-text-splitters
langchain-openai==0.1.10
langchain-openai==0.1.17
# via pipecat-ai (pyproject.toml)
langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.83
langsmith==0.1.93
# via
# langchain
# langchain-community
@@ -262,31 +254,26 @@ numpy==1.26.4
# transformers
onnxruntime==1.18.1
# via faster-whisper
openai==1.27.0
openai==1.35.15
# via
# langchain-openai
# openpipe
# pipecat-ai (pyproject.toml)
openpipe==4.16.0
openpipe==4.18.0
# via pipecat-ai (pyproject.toml)
orjson==3.10.5
# via
# fastapi
# langsmith
orjson==3.10.6
# via langsmith
packaging==24.1
# via
# huggingface-hub
# langchain-core
# marshmallow
# onnxruntime
# pytest
# transformers
pillow==10.3.0
# via
# pipecat-ai (pyproject.toml)
# torchvision
pluggy==1.5.0
# via pytest
proto-plus==1.24.0
# via
# google-ai-generativelanguage
@@ -310,7 +297,7 @@ pyasn1-modules==0.4.0
# via google-auth
pyaudio==0.2.14
# via pipecat-ai (pyproject.toml)
pydantic==2.8.0
pydantic==2.8.2
# via
# anthropic
# fastapi
@@ -319,7 +306,7 @@ pydantic==2.8.0
# langchain-core
# langsmith
# openai
pydantic-core==2.20.0
pydantic-core==2.20.1
# via pydantic
pygments==2.18.0
# via rich
@@ -329,10 +316,6 @@ pyloudnorm==0.1.1
# via pipecat-ai (pyproject.toml)
pyparsing==3.1.2
# via httplib2
pytest==8.2.2
# via pytest-asyncio
pytest-asyncio==0.23.7
# via cartesia
python-dateutil==2.9.0.post0
# via openpipe
python-dotenv==1.0.1
@@ -357,7 +340,6 @@ regex==2024.5.15
# transformers
requests==2.32.3
# via
# cartesia
# google-api-core
# huggingface-hub
# langchain
@@ -394,11 +376,11 @@ sqlalchemy==2.0.31
# langchain-community
starlette==0.37.2
# via fastapi
sympy==1.12.1
sympy==1.13.1
# via
# onnxruntime
# torch
tenacity==8.4.2
tenacity==8.5.0
# via
# langchain
# langchain-community
@@ -412,8 +394,6 @@ tokenizers==0.19.1
# anthropic
# faster-whisper
# transformers
tomli==2.0.1
# via pytest
torch==2.3.1
# via
# pipecat-ai (pyproject.toml)
@@ -453,13 +433,11 @@ typing-extensions==4.12.2
# uvicorn
typing-inspect==0.9.0
# via dataclasses-json
ujson==5.10.0
# via fastapi
uritemplate==4.1.1
# via google-api-python-client
urllib3==2.2.2
# via requests
uvicorn[standard]==0.30.1
uvicorn[standard]==0.30.3
# via fastapi
uvloop==0.19.0
# via uvicorn
@@ -469,7 +447,6 @@ watchfiles==0.22.0
# via uvicorn
websockets==12.0
# via
# cartesia
# deepgram-sdk
# pipecat-ai (pyproject.toml)
# uvicorn

View File

@@ -43,12 +43,12 @@ examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
fal = [ "fal-client~=0.4.1" ]
gladia = [ "websockets~=12.0" ]
google = [ "google-generativeai~=0.7.1" ]
fireworks = [ "openai~=1.27.0" ]
langchain = [ "langchain~=0.2.6", "langchain-community~=0.2.6", "langchain-openai~=0.1.10" ]
fireworks = [ "openai~=1.35.0" ]
langchain = [ "langchain~=0.2.10", "langchain-community~=0.2.9", "langchain-openai~=0.1.17" ]
local = [ "pyaudio~=0.2.0" ]
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
openai = [ "openai~=1.27.0" ]
openpipe = [ "openpipe~=4.16.0" ]
openai = [ "openai~=1.35.0" ]
openpipe = [ "openpipe~=4.18.0" ]
playht = [ "pyht~=0.0.28" ]
silero = [ "torch~=2.3.1", "torchaudio~=2.3.1" ]
websocket = [ "websockets~=12.0", "fastapi~=0.111.0" ]

View File

@@ -212,14 +212,6 @@ class SystemFrame(Frame):
pass
@dataclass
class StartFrame(SystemFrame):
"""This is the first frame that should be pushed down a pipeline."""
allow_interruptions: bool = False
enable_metrics: bool = False
report_only_initial_ttfb: bool = False
@dataclass
class CancelFrame(SystemFrame):
"""Indicates that a pipeline needs to stop right away."""
@@ -278,17 +270,6 @@ class BotInterruptionFrame(SystemFrame):
pass
@dataclass
class BotSpeakingFrame(SystemFrame):
"""Emitted by transport outputs while the bot is still speaking. This can be
used, for example, to detect when a user is idle. That is, while the bot is
speaking we don't want to trigger any user idle timeout since the user might
be listening.
"""
pass
@dataclass
class MetricsFrame(SystemFrame):
"""Emitted by processor that can compute metrics like latencies.
@@ -306,6 +287,14 @@ class ControlFrame(Frame):
pass
@dataclass
class StartFrame(ControlFrame):
"""This is the first frame that should be pushed down a pipeline."""
allow_interruptions: bool = False
enable_metrics: bool = False
report_only_initial_ttfb: bool = False
@dataclass
class EndFrame(ControlFrame):
"""Indicates that a pipeline has ended and frame processors and pipelines
@@ -348,6 +337,33 @@ class UserStoppedSpeakingFrame(ControlFrame):
pass
@dataclass
class BotStartedSpeakingFrame(ControlFrame):
"""Emitted upstream by transport outputs to indicate the bot started speaking.
"""
pass
@dataclass
class BotStoppedSpeakingFrame(ControlFrame):
"""Emitted upstream by transport outputs to indicate the bot stopped speaking.
"""
pass
@dataclass
class BotSpeakingFrame(ControlFrame):
"""Emitted upstream by transport outputs while the bot is still
speaking. This can be used, for example, to detect when a user is idle. That
is, while the bot is speaking we don't want to trigger any user idle timeout
since the user might be listening.
"""
pass
@dataclass
class TTSStartedFrame(ControlFrame):
"""Used to indicate the beginning of a TTS response. Following

View File

@@ -7,11 +7,13 @@
import asyncio
import dataclasses
from typing import List, Literal, Optional, Type
from pydantic import BaseModel, ValidationError
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type
from pydantic import PrivateAttr, BaseModel, ValidationError
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -33,62 +35,76 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.ai_services import AIService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService, OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
DEFAULT_MESSAGES = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
}
]
DEFAULT_MODEL = "llama3-70b-8192"
DEFAULT_VOICE = "79a125e8-cd45-4c13-8a67-188112f4dd22"
from loguru import logger
class RTVILLMConfig(BaseModel):
model: Optional[str] = None
messages: Optional[List[dict]] = None
class RTVIServiceOption(BaseModel):
name: str
handler: Optional[Callable[['RTVIProcessor',
'RTVIServiceOptionConfig'],
Awaitable[None]]] = None
class RTVITTSConfig(BaseModel):
voice: Optional[str] = None
class RTVIService(BaseModel):
name: str
cls: Type[FrameProcessor]
options: List[RTVIServiceOption]
_options_dict: Dict[str, RTVIServiceOption] = PrivateAttr(default={})
def model_post_init(self, __context: Any) -> None:
self._options_dict = {}
for option in self.options:
self._options_dict[option.name] = option
return super().model_post_init(__context)
#
# Client -> Pipecat messages.
#
class RTVIServiceOptionConfig(BaseModel):
name: str
value: Any
class RTVIServiceConfig(BaseModel):
service: str
options: List[RTVIServiceOptionConfig]
class RTVIConfig(BaseModel):
llm: Optional[RTVILLMConfig] = None
tts: Optional[RTVITTSConfig] = None
config: List[RTVIServiceConfig]
_config_dict: Dict[str, RTVIServiceConfig] = PrivateAttr(default={})
def model_post_init(self, __context: Any) -> None:
self._config_dict = {}
for c in self.config:
self._config_dict[c.service] = c
return super().model_post_init(__context)
class RTVISetup(BaseModel):
config: Optional[RTVIConfig] = None
class RTVILLMMessageData(BaseModel):
class RTVILLMContextData(BaseModel):
messages: List[dict]
class RTVITTSMessageData(BaseModel):
class RTVITTSSpeakData(BaseModel):
text: str
interrupt: Optional[bool] = False
class RTVIMessageData(BaseModel):
setup: Optional[RTVISetup] = None
config: Optional[RTVIConfig] = None
llm: Optional[RTVILLMMessageData] = None
tts: Optional[RTVITTSMessageData] = None
class RTVIMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: str
id: str
data: Optional[RTVIMessageData] = None
data: Optional[Dict[str, Any]] = None
#
# Pipecat -> Client responses and messages.
#
class RTVIResponseData(BaseModel):
@@ -97,7 +113,7 @@ class RTVIResponseData(BaseModel):
class RTVIResponse(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["response"] = "response"
id: str
data: RTVIResponseData
@@ -108,7 +124,7 @@ class RTVIErrorData(BaseModel):
class RTVIError(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["error"] = "error"
data: RTVIErrorData
@@ -118,7 +134,7 @@ class RTVILLMContextMessageData(BaseModel):
class RTVILLMContextMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["llm-context"] = "llm-context"
data: RTVILLMContextMessageData
@@ -128,13 +144,13 @@ class RTVITTSTextMessageData(BaseModel):
class RTVITTSTextMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["tts-text"] = "tts-text"
data: RTVITTSTextMessageData
class RTVIBotReady(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-ready"] = "bot-ready"
@@ -146,23 +162,23 @@ class RTVITranscriptionMessageData(BaseModel):
class RTVITranscriptionMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-transcription"] = "user-transcription"
data: RTVITranscriptionMessageData
class RTVIUserStartedSpeakingMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-started-speaking"] = "user-started-speaking"
class RTVIUserStoppedSpeakingMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-stopped-speaking"] = "user-stopped-speaking"
class RTVIJSONCompletion(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["json-completion"] = "json-completion"
data: str
@@ -265,59 +281,128 @@ class RTVITTSTextProcessor(FrameProcessor):
await self.push_frame(TransportMessageFrame(message=message.model_dump(exclude_none=True)))
async def handle_llm_model_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
frame = LLMModelUpdateFrame(option.value)
await rtvi.push_frame(frame)
async def handle_llm_messages_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
frame = LLMMessagesUpdateFrame(option.value)
await rtvi.push_frame(frame)
async def handle_tts_voice_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
frame = TTSVoiceUpdateFrame(option.value)
await rtvi.push_frame(frame)
DEFAULT_LLM_SERVICE = RTVIService(
name="llm",
cls=OpenAILLMService,
options=[
RTVIServiceOption(name="model", handler=handle_llm_model_update),
RTVIServiceOption(name="messages", handler=handle_llm_messages_update)
])
DEFAULT_TTS_SERVICE = RTVIService(
name="tts",
cls=CartesiaTTSService,
options=[
RTVIServiceOption(name="voice_id", handler=handle_tts_voice_update),
])
class RTVIProcessor(FrameProcessor):
def __init__(
self,
*,
transport: BaseTransport,
setup: RTVISetup | None = None,
llm_api_key: str = "",
llm_base_url: str = "https://api.groq.com/openai/v1",
tts_api_key: str = "",
llm_cls: Type[AIService] = OpenAILLMService,
tts_cls: Type[AIService] = CartesiaTTSService):
def __init__(self, *, transport: BaseTransport):
super().__init__()
self._transport = transport
self._setup = setup
self._llm_api_key = llm_api_key
self._llm_base_url = llm_base_url
self._tts_api_key = tts_api_key
self._llm_cls = llm_cls
self._tts_cls = tts_cls
self._config: RTVIConfig | None = None
self._ctor_args: Dict[str, Any] = {}
self._start_frame: Frame | None = None
self._llm: FrameProcessor | None = None
self._tts: FrameProcessor | None = None
self._pipeline: FrameProcessor | None = None
self._first_participant_joined: bool = False
# Register transport event so we can send a `bot-ready` event (and maybe
# others) when the participant joins.
transport.add_event_handler(
"on_first_participant_joined",
self._on_first_participant_joined)
# Register default services.
self._registered_services: Dict[str, RTVIService] = {}
self.register_service(DEFAULT_LLM_SERVICE)
self.register_service(DEFAULT_TTS_SERVICE)
self._frame_handler_task = self.get_event_loop().create_task(self._frame_handler())
self._frame_queue = asyncio.Queue()
def register_service(self, service: RTVIService):
self._registered_services[service.name] = service
def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]):
self._config = config
self._ctor_args = ctor_args
async def update_config(self, config: RTVIConfig):
if self._pipeline:
await self._handle_config_update(config)
self._config = config
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, SystemFrame):
# Specific system frames
if isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
# Control frames
elif isinstance(frame, StartFrame):
await self._start(frame)
await self._internal_push_frame(frame, direction)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self._stop(frame)
# Other frames
else:
await self._frame_queue.put((frame, direction))
if isinstance(frame, StartFrame):
self._start_frame = frame
try:
await self._handle_setup(self._setup)
except Exception as e:
await self._send_error(f"unable to setup RTVI: {e}")
await self._internal_push_frame(frame, direction)
async def cleanup(self):
if self._pipeline:
await self._pipeline.cleanup()
async def _start(self, frame: StartFrame):
try:
await self._handle_pipeline_setup(frame, self._config)
except Exception as e:
await self._send_error(f"unable to setup RTVI pipeline: {e}")
async def _stop(self, frame: EndFrame):
await self._frame_handler_task
async def _cancel(self, frame: CancelFrame):
self._frame_handler_task.cancel()
await self._frame_handler_task
async def _internal_push_frame(
self,
frame: Frame | None,
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
await self._frame_queue.put((frame, direction))
async def _frame_handler(self):
while True:
running = True
while running:
try:
(frame, direction) = await self._frame_queue.get()
await self._handle_frame(frame, direction)
self._frame_queue.task_done()
running = not isinstance(frame, EndFrame)
except asyncio.CancelledError:
break
@@ -372,113 +457,102 @@ class RTVIProcessor(FrameProcessor):
try:
message = RTVIMessage.model_validate(frame.message)
except ValidationError as e:
await self._send_error(f"invalid message: {e}")
await self._send_error(f"Invalid incoming message: {e}")
logger.warning(f"Invalid incoming message: {e}")
return
try:
success = True
error = None
match message.type:
case "setup":
setup = None
if message.data:
setup = message.data.setup
await self._handle_setup(message.id, setup)
case "config-update":
await self._handle_config_update(message.data.config)
await self._handle_config_update(RTVIConfig.model_validate(message.data))
case "llm-get-context":
await self._handle_llm_get_context()
case "llm-append-context":
await self._handle_llm_append_context(message.data.llm)
await self._handle_llm_append_context(RTVILLMContextData.model_validate(message.data))
case "llm-update-context":
await self._handle_llm_update_context(message.data.llm)
await self._handle_llm_update_context(RTVILLMContextData.model_validate(message.data))
case "tts-speak":
await self._handle_tts_speak(message.data.tts)
await self._handle_tts_speak(RTVITTSSpeakData.model_validate(message.data))
case "tts-interrupt":
await self._handle_tts_interrupt()
case _:
success = False
error = f"unsupported type {message.type}"
error = f"Unsupported type {message.type}"
await self._send_response(message.id, success, error)
except ValidationError as e:
await self._send_response(message.id, False, f"invalid message: {e}")
await self._send_response(message.id, False, f"Invalid incoming message: {e}")
logger.warning(f"Invalid incoming message: {e}")
except Exception as e:
await self._send_response(message.id, False, f"{e}")
await self._send_response(message.id, False, f"Exception processing message: {e}")
logger.warning(f"Exception processing message: {e}")
async def _handle_setup(self, setup: RTVISetup | None):
model = DEFAULT_MODEL
if setup and setup.config and setup.config.llm and setup.config.llm.model:
model = setup.config.llm.model
async def _handle_pipeline_setup(self, start_frame: StartFrame, config: RTVIConfig | None):
# TODO(aleix): We shouldn't need to save this in `self._tma_in`.
self._tma_in = LLMUserResponseAggregator()
tma_out = LLMAssistantResponseAggregator()
messages = DEFAULT_MESSAGES
if setup and setup.config and setup.config.llm and setup.config.llm.messages:
messages = setup.config.llm.messages
llm_cls = self._registered_services["llm"].cls
llm_args = self._ctor_args["llm"]
llm = llm_cls(**llm_args)
voice = DEFAULT_VOICE
if setup and setup.config and setup.config.tts and setup.config.tts.voice:
voice = setup.config.tts.voice
self._tma_in = LLMUserResponseAggregator(messages)
self._tma_out = LLMAssistantResponseAggregator(messages)
self._llm = self._llm_cls(
name="LLM",
base_url=self._llm_base_url,
api_key=self._llm_api_key,
model=model)
self._tts = self._tts_cls(name="TTS", api_key=self._tts_api_key, voice_id=voice)
tts_cls = self._registered_services["tts"].cls
tts_args = self._ctor_args["tts"]
tts = tts_cls(**tts_args)
# TODO-CB: Eventually we'll need to switch the context aggregators to use the
# OpenAI context frames instead of message frames
context = OpenAILLMContext(messages=messages)
self._fc = FunctionCaller(context)
context = OpenAILLMContext()
fc = FunctionCaller(context)
self._tts_text = RTVITTSTextProcessor()
tts_text = RTVITTSTextProcessor()
pipeline = Pipeline([
self._tma_in,
self._llm,
self._fc,
self._tts,
self._tts_text,
self._tma_out,
llm,
fc,
tts,
tts_text,
tma_out,
self._transport.output(),
])
self._pipeline = pipeline
parent = self.get_parent()
if parent and self._start_frame:
if parent:
parent.link(pipeline)
# We need to initialize the new pipeline with the same settings
# as the initial one.
start_frame = dataclasses.replace(self._start_frame)
start_frame = dataclasses.replace(start_frame)
await self.push_frame(start_frame)
# Configure the pipeline
if config:
await self._handle_config_update(config)
# Send new initial metrics with the new processors
processors = parent.processors_with_metrics()
processors.extend(self._pipeline.processors_with_metrics())
processors.extend(pipeline.processors_with_metrics())
ttfb = [{"processor": p.name, "value": 0.0} for p in processors]
processing = [{"processor": p.name, "value": 0.0} for p in processors]
await self.push_frame(MetricsFrame(ttfb=ttfb, processing=processing))
message = RTVIBotReady()
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
self._pipeline = pipeline
async def _handle_config_update(self, config: RTVIConfig):
# Change voice before LLM updates, so we can hear the new vocie.
if config.tts and config.tts.voice:
frame = TTSVoiceUpdateFrame(config.tts.voice)
await self.push_frame(frame)
if config.llm and config.llm.model:
frame = LLMModelUpdateFrame(config.llm.model)
await self.push_frame(frame)
if config.llm and config.llm.messages:
frame = LLMMessagesUpdateFrame(config.llm.messages)
await self.push_frame(frame)
await self._maybe_send_bot_ready()
async def _handle_config_service(self, config: RTVIServiceConfig):
service = self._registered_services[config.service]
for option in config.options:
handler = service._options_dict[option.name].handler
if handler:
await handler(self, option)
async def _handle_config_update(self, data: RTVIConfig):
for config in data.config:
await self._handle_config_service(config)
async def _handle_llm_get_context(self):
data = RTVILLMContextMessageData(messages=self._tma_in.messages)
@@ -486,17 +560,17 @@ class RTVIProcessor(FrameProcessor):
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
async def _handle_llm_append_context(self, data: RTVILLMMessageData):
async def _handle_llm_append_context(self, data: RTVILLMContextData):
if data and data.messages:
frame = LLMMessagesAppendFrame(data.messages)
await self.push_frame(frame)
async def _handle_llm_update_context(self, data: RTVILLMMessageData):
async def _handle_llm_update_context(self, data: RTVILLMContextData):
if data and data.messages:
frame = LLMMessagesUpdateFrame(data.messages)
await self.push_frame(frame)
async def _handle_tts_speak(self, data: RTVITTSMessageData):
async def _handle_tts_speak(self, data: RTVITTSSpeakData):
if data and data.text:
if data.interrupt:
await self._handle_tts_interrupt()
@@ -506,6 +580,16 @@ class RTVIProcessor(FrameProcessor):
async def _handle_tts_interrupt(self):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
async def _on_first_participant_joined(self, transport, participant):
self._first_participant_joined = True
await self._maybe_send_bot_ready()
async def _maybe_send_bot_ready(self):
if self._pipeline and self._first_participant_joined:
message = RTVIBotReady()
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
async def _send_error(self, error: str):
message = RTVIError(data=RTVIErrorData(message=error))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
@@ -523,7 +607,7 @@ class RTVIProcessor(FrameProcessor):
self._pipeline = pipeline
parent = self.get_parent()
if parent and self._start_frame:
if parent:
parent.link(pipeline)
message = RTVIResponse(id=id, data=RTVIResponseData(success=success, error=error))

View File

@@ -8,7 +8,12 @@ import asyncio
from typing import Awaitable, Callable
from pipecat.frames.frames import BotSpeakingFrame, Frame, StartInterruptionFrame, StopInterruptionFrame, SystemFrame
from pipecat.frames.frames import (
BotSpeakingFrame,
Frame,
SystemFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
from pipecat.processors.frame_processor import FrameDirection
@@ -47,10 +52,10 @@ class UserIdleProcessor(AsyncFrameProcessor):
await self.queue_frame(frame, direction)
# We shouldn't call the idle callback if the user or the bot are speaking.
if isinstance(frame, StartInterruptionFrame):
if isinstance(frame, UserStartedSpeakingFrame):
self._interrupted = True
self._idle_event.set()
elif isinstance(frame, StopInterruptionFrame):
elif isinstance(frame, UserStoppedSpeakingFrame):
self._interrupted = False
self._idle_event.set()
elif isinstance(frame, BotSpeakingFrame):

View File

@@ -283,14 +283,17 @@ class STTService(AIService):
await self.stop_processing_metrics()
(self._content, self._wave) = self._new_wave()
async def stop(self, frame: EndFrame):
self._wave.close()
async def cancel(self, frame: CancelFrame):
self._wave.close()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
await super().process_frame(frame, direction)
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
self._wave.close()
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
if isinstance(frame, AudioRawFrame):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We don't really want to push audio frames down.
await self._append_audio(frame)

View File

@@ -147,13 +147,16 @@ class AzureSTTService(AsyncAIService):
await self._push_queue.put((frame, direction))
async def start(self, frame: StartFrame):
await super().start(frame)
self._speech_recognizer.start_continuous_recognition_async()
async def stop(self, frame: EndFrame):
await super().stop(frame)
self._speech_recognizer.stop_continuous_recognition_async()
self._audio_stream.close()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
self._speech_recognizer.stop_continuous_recognition_async()
self._audio_stream.close()
@@ -168,12 +171,12 @@ class AzureImageGenServiceREST(ImageGenService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
image_size: str,
api_key: str,
endpoint: str,
model: str,
api_version="2023-06-01-preview",
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
@@ -181,8 +184,14 @@ class AzureImageGenServiceREST(ImageGenService):
self._azure_endpoint = endpoint
self._api_version = api_version
self._model = model
self._aiohttp_session = aiohttp_session
self._image_size = image_size
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"

View File

@@ -14,6 +14,7 @@ from typing import AsyncGenerator
from pipecat.processors.frame_processor import FrameDirection
from pipecat.frames.frames import (
CancelFrame,
Frame,
AudioRawFrame,
StartInterruptionFrame,
@@ -98,6 +99,10 @@ class CartesiaTTSService(TTSService):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def _connect(self):
try:
self._websocket = await websockets.connect(
@@ -111,6 +116,8 @@ class CartesiaTTSService(TTSService):
async def _disconnect(self):
try:
await self.stop_all_metrics()
if self._context_appending_task:
self._context_appending_task.cancel()
await self._context_appending_task
@@ -120,13 +127,12 @@ class CartesiaTTSService(TTSService):
await self._receive_task
self._receive_task = None
if self._websocket:
ws = self._websocket
await self._websocket.close()
self._websocket = None
await ws.close()
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
await self.stop_all_metrics()
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")
@@ -142,13 +148,13 @@ class CartesiaTTSService(TTSService):
try:
async for message in self._websocket:
msg = json.loads(message)
# logger.debug(f"Received message: {msg['type']} {msg['context_id']}")
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# unset _context_id but not the _context_id_start_timestamp because we are likely still
# playing out audio and need the timestamp to set send context frames
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
elif msg["type"] == "timestamps":
@@ -166,6 +172,8 @@ class CartesiaTTSService(TTSService):
num_channels=1
)
await self.push_frame(frame)
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")
@@ -176,15 +184,17 @@ class CartesiaTTSService(TTSService):
if not self._context_id_start_timestamp:
continue
elapsed_seconds = time.time() - self._context_id_start_timestamp
# pop all words from self._timestamped_words_buffer that are older than the
# elapsed time and print a message about them to the console
# Pop all words from self._timestamped_words_buffer that are
# older than the elapsed time and print a message about them to
# the console.
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
word, timestamp = self._timestamped_words_buffer.pop(0)
if word == "LLMFullResponseEndFrame" and timestamp == 0:
await self.push_frame(LLMFullResponseEndFrame())
continue
# print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.")
await self.push_frame(TextFrame(word))
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")
@@ -212,7 +222,6 @@ class CartesiaTTSService(TTSService):
"language": self._language,
"add_timestamps": True,
}
# logger.debug(f"SENDING MESSAGE {json.dumps(msg)}")
try:
await self._websocket.send(json.dumps(msg))
except Exception as e:

View File

@@ -45,21 +45,31 @@ class DeepgramTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
api_key: str,
voice: str = "aura-helios-en",
base_url: str = "https://api.deepgram.com/v1/speak",
sample_rate: int = 16000,
encoding: str = "linear16",
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)
self._voice = voice
self._api_key = api_key
self._aiohttp_session = aiohttp_session
self._base_url = base_url
self._sample_rate = sample_rate
self._encoding = encoding
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
def can_generate_metrics(self) -> bool:
return True
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice = voice
@@ -68,7 +78,7 @@ class DeepgramTTSService(TTSService):
logger.debug(f"Generating TTS: [{text}]")
base_url = self._base_url
request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000"
request_url = f"{base_url}?model={self._voice}&encoding={self._encoding}&container=none&sample_rate={self._sample_rate}"
headers = {"authorization": f"token {self._api_key}"}
body = {"text": text}
@@ -91,7 +101,7 @@ class DeepgramTTSService(TTSService):
async for data in r.content:
await self.stop_ttfb_metrics()
frame = AudioRawFrame(audio=data, sample_rate=16000, num_channels=1)
frame = AudioRawFrame(audio=data, sample_rate=self._sample_rate, num_channels=1)
yield frame
except Exception as e:
logger.exception(f"{self} exception: {e}")
@@ -132,15 +142,18 @@ class DeepgramSTTService(AsyncAIService):
await self.queue_frame(frame, direction)
async def start(self, frame: StartFrame):
await super().start(frame)
if await self._connection.start(self._live_options):
logger.debug(f"{self}: Connected to Deepgram")
else:
logger.error(f"{self}: Unable to connect to Deepgram")
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._connection.finish()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._connection.finish()
async def _on_message(self, *args, **kwargs):

View File

@@ -19,21 +19,27 @@ class ElevenLabsTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
api_key: str,
voice_id: str,
model: str = "eleven_turbo_v2",
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)
self._api_key = api_key
self._voice_id = voice_id
self._aiohttp_session = aiohttp_session
self._model = model
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
def can_generate_metrics(self) -> bool:
return True
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice

View File

@@ -39,18 +39,24 @@ class FalImageGenService(ImageGenService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
params: InputParams,
model: str = "fal-ai/fast-sdxl",
key: str | None = None,
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
self._model = model
self._params = params
self._aiohttp_session = aiohttp_session
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
if key:
os.environ["FAL_KEY"] = key
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")

View File

@@ -68,14 +68,17 @@ class GladiaSTTService(AsyncAIService):
await self.queue_frame(frame, direction)
async def start(self, frame: StartFrame):
await super().start(frame)
self._websocket = await websockets.connect(self._url)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
await self._setup_gladia()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._websocket.close()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._websocket.close()
async def _setup_gladia(self):

View File

@@ -253,16 +253,22 @@ class OpenAIImageGenService(ImageGenService):
def __init__(
self,
*,
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
aiohttp_session: aiohttp.ClientSession,
api_key: str,
model: str = "dall-e-3",
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
self._model = model
self._image_size = image_size
self._client = AsyncOpenAI(api_key=api_key)
self._aiohttp_session = aiohttp_session
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")

View File

@@ -38,22 +38,28 @@ class XTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
voice_id: str,
language: str,
base_url: str,
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)
self._voice_id = voice_id
self._language = language
self._base_url = base_url
self._aiohttp_session = aiohttp_session
self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json()
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
def can_generate_metrics(self) -> bool:
return True
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice

View File

@@ -18,6 +18,7 @@ from pipecat.frames.frames import (
Frame,
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
from pipecat.transports.base_transport import TransportParams
@@ -45,12 +46,26 @@ class BaseInputTransport(FrameProcessor):
self._audio_in_queue = asyncio.Queue()
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())
async def stop(self):
# Wait for the task to finish.
async def stop(self, frame: EndFrame):
# Cancel and wait for the audio input task to finish.
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_task.cancel()
await self._audio_task
# Wait for the push frame task to finish. It will finish when the
# EndFrame is actually processed.
await self._push_frame_task
async def cancel(self, frame: CancelFrame):
# Cancel all the tasks and wait for them to finish.
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_task.cancel()
await self._audio_task
self._push_frame_task.cancel()
await self._push_frame_task
def vad_analyzer(self) -> VADAnalyzer | None:
return self._params.vad_analyzer
@@ -62,25 +77,32 @@ class BaseInputTransport(FrameProcessor):
# Frame processor
#
async def cleanup(self):
self._push_frame_task.cancel()
await self._push_frame_task
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# Specific system frames
if isinstance(frame, CancelFrame):
await self.stop()
# We don't queue a CancelFrame since we want to stop ASAP.
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
await self._handle_interruptions(frame, False)
elif isinstance(frame, StartInterruptionFrame):
await self._start_interruption()
elif isinstance(frame, StopInterruptionFrame):
await self._stop_interruption()
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
# Control frames
elif isinstance(frame, StartFrame):
await self.start(frame)
await self._internal_push_frame(frame, direction)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self.stop()
elif isinstance(frame, BotInterruptionFrame):
await self._handle_interruptions(frame, False)
await self.stop(frame)
# Other frames
else:
await self._internal_push_frame(frame, direction)
@@ -100,10 +122,12 @@ class BaseInputTransport(FrameProcessor):
await self._push_queue.put((frame, direction))
async def _push_frame_task_handler(self):
while True:
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break
@@ -113,6 +137,9 @@ class BaseInputTransport(FrameProcessor):
#
async def _start_interruption(self):
if not self.interruptions_allowed:
return
# Cancel the task. This will stop pushing frames downstream.
self._push_frame_task.cancel()
await self._push_frame_task
@@ -124,6 +151,9 @@ class BaseInputTransport(FrameProcessor):
self._create_push_task()
async def _stop_interruption(self):
if not self.interruptions_allowed:
return
await self.push_frame(StopInterruptionFrame())
async def _handle_interruptions(self, frame: Frame, push_frame: bool):

View File

@@ -15,6 +15,8 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.frames.frames import (
AudioRawFrame,
BotSpeakingFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
MetricsFrame,
SpriteFrame,
@@ -25,6 +27,8 @@ from pipecat.frames.frames import (
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
TTSStartedFrame,
TTSStoppedFrame,
TransportMessageFrame)
from pipecat.transports.base_transport import TransportParams
@@ -60,18 +64,34 @@ class BaseOutputTransport(FrameProcessor):
self._create_push_task()
async def start(self, frame: StartFrame):
# Create media threads queues.
# Create camera output queue and task if needed.
if self._params.camera_out_enabled:
self._camera_out_queue = asyncio.Queue()
self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler())
async def stop(self):
# Wait on the threads to finish.
async def stop(self, frame: EndFrame):
# Cancel and wait for the camera output task to finish.
if self._params.camera_out_enabled:
self._camera_out_task.cancel()
await self._camera_out_task
self._stopped_event.set()
# Wait for the push frame and sink tasks to finish. They will finish when
# the EndFrame is actually processed.
await self._push_frame_task
await self._sink_task
async def cancel(self, frame: CancelFrame):
# Cancel all the tasks and wait for them to finish.
if self._params.camera_out_enabled:
self._camera_out_task.cancel()
await self._camera_out_task
self._push_frame_task.cancel()
await self._push_frame_task
self._sink_task.cancel()
await self._sink_task
async def send_message(self, frame: TransportMessageFrame):
pass
@@ -89,48 +109,38 @@ class BaseOutputTransport(FrameProcessor):
# Frame processor
#
async def cleanup(self):
if self._sink_task:
self._sink_task.cancel()
await self._sink_task
self._push_frame_task.cancel()
await self._push_frame_task
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
#
# Out-of-band frames like (CancelFrame or StartInterruptionFrame) are
# pushed immediately. Other frames require order so they are put in the
# sink queue.
# System frames (like StartInterruptionFrame) are pushed
# immediately. Other frames require order so they are put in the sink
# queue.
#
if isinstance(frame, StartFrame):
await self.start(frame)
await self.push_frame(frame, direction)
# EndFrame is managed in the sink queue handler.
elif isinstance(frame, CancelFrame):
await self.stop()
if isinstance(frame, CancelFrame):
await self.push_frame(frame, direction)
await self.cancel(frame)
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
await self.push_frame(frame, direction)
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, MetricsFrame):
await self.send_metrics(frame)
await self.push_frame(frame, direction)
await self.send_metrics(frame)
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
# Control frames.
elif isinstance(frame, StartFrame):
await self._sink_queue.put(frame)
await self.start(frame)
elif isinstance(frame, EndFrame):
await self._sink_queue.put(frame)
await self.stop(frame)
# Other frames.
elif isinstance(frame, AudioRawFrame):
await self._handle_audio(frame)
else:
await self._sink_queue.put(frame)
# If we are finishing, wait here until we have stopped, otherwise we might
# close things too early upstream. We need this event because we don't
# know when the internal threads will finish.
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
await self._stopped_event.wait()
async def _handle_interruptions(self, frame: Frame):
if not self.interruptions_allowed:
return
@@ -160,7 +170,9 @@ class BaseOutputTransport(FrameProcessor):
async def _sink_task_handler(self):
# Audio accumlation buffer
buffer = bytearray()
while True:
running = True
while running:
try:
frame = await self._sink_queue.get()
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
@@ -172,11 +184,16 @@ class BaseOutputTransport(FrameProcessor):
await self._set_camera_images(frame.images)
elif isinstance(frame, TransportMessageFrame):
await self.send_message(frame)
elif isinstance(frame, TTSStartedFrame):
await self._internal_push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
await self._internal_push_frame(frame)
elif isinstance(frame, TTSStoppedFrame):
await self._internal_push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
await self._internal_push_frame(frame)
else:
await self._internal_push_frame(frame)
if isinstance(frame, EndFrame):
await self.stop()
running = not isinstance(frame, EndFrame)
self._sink_queue.task_done()
except asyncio.CancelledError:
@@ -200,10 +217,12 @@ class BaseOutputTransport(FrameProcessor):
await self._push_queue.put((frame, direction))
async def _push_frame_task_handler(self):
while True:
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break

View File

@@ -60,20 +60,20 @@ class BaseTransport(ABC):
def event_handler(self, event_name: str):
def decorator(handler):
self._add_event_handler(event_name, handler)
self.add_event_handler(event_name, handler)
return handler
return decorator
def add_event_handler(self, event_name: str, handler):
if event_name not in self._event_handlers:
raise Exception(f"Event handler {event_name} not registered")
self._event_handlers[event_name].append(handler)
def _register_event_handler(self, event_name: str):
if event_name in self._event_handlers:
raise Exception(f"Event handler {event_name} already registered")
self._event_handlers[event_name] = []
def _add_event_handler(self, event_name: str, handler):
if event_name not in self._event_handlers:
raise Exception(f"Event handler {event_name} not registered")
self._event_handlers[event_name].append(handler)
async def _call_event_handler(self, event_name: str, *args, **kwargs):
try:
for handler in self._event_handlers[event_name]:

View File

@@ -12,7 +12,7 @@ import wave
from typing import Awaitable, Callable
from pydantic.main import BaseModel
from pipecat.frames.frames import AudioRawFrame, StartFrame
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.base_input import BaseInputTransport
@@ -57,14 +57,19 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
self._callbacks = callbacks
async def start(self, frame: StartFrame):
await self._callbacks.on_client_connected(self._websocket)
await super().start(frame)
await self._callbacks.on_client_connected(self._websocket)
self._receive_task = self.get_event_loop().create_task(self._receive_messages())
async def stop(self):
async def stop(self, frame: EndFrame):
await super().stop(frame)
if self._websocket.client_state != WebSocketState.DISCONNECTED:
await self._websocket.close()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
if self._websocket.client_state != WebSocketState.DISCONNECTED:
await self._websocket.close()
await super().stop()
async def _receive_messages(self):
async for message in self._websocket.iter_text():

View File

@@ -11,7 +11,7 @@ import wave
from typing import Awaitable, Callable
from pydantic.main import BaseModel
from pipecat.frames.frames import AudioRawFrame, StartFrame
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.serializers.protobuf import ProtobufFrameSerializer
@@ -64,10 +64,15 @@ class WebsocketServerInputTransport(BaseInputTransport):
self._server_task = self.get_event_loop().create_task(self._server_task_handler())
await super().start(frame)
async def stop(self):
async def stop(self, frame: EndFrame):
await super().stop(frame)
self._stop_server_event.set()
await self._server_task
await super().stop()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
self._server_task.cancel()
await self._server_task
async def _server_task_handler(self):
logger.info(f"Starting websocket server on {self._host}:{self._port}")

View File

@@ -23,6 +23,8 @@ from pydantic.main import BaseModel
from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
Frame,
ImageRawFrame,
InterimTranscriptionFrame,
@@ -125,8 +127,15 @@ class DailyCallbacks(BaseModel):
def completion_callback(future):
def _callback(*args):
if not future.cancelled():
future.get_loop().call_soon_threadsafe(future.set_result, *args)
def set_result(future, *args):
try:
if len(args) > 1:
future.set_result(args)
else:
future.set_result(*args)
except asyncio.InvalidStateError:
pass
future.get_loop().call_soon_threadsafe(set_result, future, *args)
return _callback
@@ -282,11 +291,12 @@ class DailyTransportClient(EventHandler):
await self._callbacks.on_error(error_msg)
async def _start_transcription(self):
future = self._loop.create_future()
logger.info(f"Enabling transcription with settings {self._params.transcription_settings}")
future = self._loop.create_future()
self._client.start_transcription(
settings=self._params.transcription_settings.model_dump(exclude_none=True),
completion=lambda error: future.set_result(error)
completion=completion_callback(future)
)
error = await future
if error:
@@ -295,14 +305,10 @@ class DailyTransportClient(EventHandler):
async def _join(self):
future = self._loop.create_future()
def handle_join_response(data, error):
if not future.cancelled():
future.get_loop().call_soon_threadsafe(future.set_result, (data, error))
self._client.join(
self._room_url,
self._token,
completion=handle_join_response,
completion=completion_callback(future),
client_settings={
"inputs": {
"camera": {
@@ -370,20 +376,14 @@ class DailyTransportClient(EventHandler):
async def _stop_transcription(self):
future = self._loop.create_future()
self._client.stop_transcription(completion=lambda error: future.set_result(error))
self._client.stop_transcription(completion=completion_callback(future))
error = await future
if error:
logger.error(f"Unable to stop transcription: {error}")
async def _leave(self):
future = self._loop.create_future()
def handle_leave_response(error):
if not future.cancelled():
future.get_loop().call_soon_threadsafe(future.set_result, error)
self._client.leave(completion=handle_leave_response)
self._client.leave(completion=completion_callback(future))
return await asyncio.wait_for(future, timeout=10)
async def cleanup(self):
@@ -547,9 +547,19 @@ class DailyInputTransport(BaseInputTransport):
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler())
async def stop(self):
async def stop(self, frame: EndFrame):
# Parent stop.
await super().stop()
await super().stop(frame)
# Leave the room.
await self._client.leave()
# Stop audio thread.
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_in_task.cancel()
await self._audio_in_task
async def cancel(self, frame: CancelFrame):
# Parent stop.
await super().cancel(frame)
# Leave the room.
await self._client.leave()
# Stop audio thread.
@@ -664,9 +674,15 @@ class DailyOutputTransport(BaseOutputTransport):
# Join the room.
await self._client.join()
async def stop(self):
async def stop(self, frame: EndFrame):
# Parent stop.
await super().stop()
await super().stop(frame)
# Leave the room.
await self._client.leave()
async def cancel(self, frame: CancelFrame):
# Parent stop.
await super().cancel(frame)
# Leave the room.
await self._client.leave()