Compare commits
24 Commits
v0.0.38
...
aleix/stop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d6b8b035e | ||
|
|
0a15874c12 | ||
|
|
d60e99a043 | ||
|
|
77723b34c7 | ||
|
|
c466d34a06 | ||
|
|
f816897833 | ||
|
|
c1e8a5e522 | ||
|
|
76aca32f2e | ||
|
|
7e31b2a795 | ||
|
|
028e38a86b | ||
|
|
8cf7649855 | ||
|
|
64f5119b08 | ||
|
|
4d606aefb3 | ||
|
|
4bafdaa04d | ||
|
|
5afe1abf82 | ||
|
|
f066d50b98 | ||
|
|
91103e21cc | ||
|
|
f44dabcd65 | ||
|
|
0fd2fca231 | ||
|
|
5bb64098e7 | ||
|
|
3fc85e75e0 | ||
|
|
3f61ea16b7 | ||
|
|
4b393092b5 | ||
|
|
b583f5162b |
3
.github/workflows/publish_test.yaml
vendored
3
.github/workflows/publish_test.yaml
vendored
@@ -9,6 +9,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-tags: true
|
||||
fetch-depth: 100
|
||||
- name: Set up Python
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v4
|
||||
|
||||
37
CHANGELOG.md
37
CHANGELOG.md
@@ -5,6 +5,43 @@ All notable changes to **pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added new `BotStartedSpeakingFrame` and `BotStoppedSpeakingFrame` control
|
||||
frames. These frames are pushed upstream and they should wrap
|
||||
`BotSpeakingFrame`.
|
||||
|
||||
- Transports now allow you to register event handlers without decorators.
|
||||
|
||||
### Changed
|
||||
|
||||
- `BotSpeakingFrame` is now a control frame.
|
||||
|
||||
- `StartFrame` is now a control frame similar to `EndFrame`.
|
||||
|
||||
- `DeepgramTTSService` now is more customizable. You can adjust the encoding and
|
||||
sample rate.
|
||||
|
||||
### Fixed
|
||||
|
||||
- RTVI's `bot-ready` message is now sent when the RTVI pipeline is ready and
|
||||
a first participant joins.
|
||||
|
||||
- Fixed a `BaseInputTransport` issue that was causing incoming system frames to
|
||||
be queued instead of being pushed immediately.
|
||||
|
||||
- Fixed a `BaseInputTransport` issue that was causing start/stop interruptions
|
||||
incoming frames to not cancel tasks and be processed properly.
|
||||
|
||||
## [0.0.39] - 2024-07-23
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a regression introduced in 0.0.38 that would cause Daily transcription
|
||||
to stop the Pipeline.
|
||||
|
||||
## [0.0.38] - 2024-07-23
|
||||
|
||||
### Added
|
||||
|
||||
@@ -4,5 +4,5 @@ grpcio-tools~=1.62.2
|
||||
pip-tools~=7.4.1
|
||||
pyright~=1.1.367
|
||||
pytest~=8.2.0
|
||||
setuptools~=69.5.1
|
||||
setuptools~=71.1.0
|
||||
setuptools_scm~=8.1.0
|
||||
|
||||
@@ -51,7 +51,7 @@ class ImageSyncAggregator(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if not isinstance(frame, SystemFrame):
|
||||
if not isinstance(frame, SystemFrame) and direction == FrameDirection.DOWNSTREAM:
|
||||
await self.push_frame(ImageRawFrame(image=self._speaking_image_bytes, size=(1024, 1024), format=self._speaking_image_format))
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(ImageRawFrame(image=self._waiting_image_bytes, size=(1024, 1024), format=self._waiting_image_format))
|
||||
|
||||
@@ -8,7 +8,6 @@ aiofiles==24.1.0
|
||||
# via deepgram-sdk
|
||||
aiohttp==3.9.5
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -36,17 +35,15 @@ attrs==23.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# openpipe
|
||||
av==12.2.0
|
||||
av==12.3.0
|
||||
# via faster-whisper
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
blinker==1.8.2
|
||||
# via flask
|
||||
cachetools==5.3.3
|
||||
cachetools==5.4.0
|
||||
# via google-auth
|
||||
cartesia==1.0.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
certifi==2024.6.2
|
||||
certifi==2024.7.4
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
@@ -80,13 +77,11 @@ einops==0.8.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
email-validator==2.2.0
|
||||
# via fastapi
|
||||
exceptiongroup==1.2.1
|
||||
# via
|
||||
# anyio
|
||||
# pytest
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fal-client==0.4.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
fastapi==0.111.0
|
||||
fastapi==0.111.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
@@ -124,9 +119,9 @@ google-api-core[grpc]==2.19.1
|
||||
# google-ai-generativelanguage
|
||||
# google-api-python-client
|
||||
# google-generativeai
|
||||
google-api-python-client==2.135.0
|
||||
google-api-python-client==2.137.0
|
||||
# via google-generativeai
|
||||
google-auth==2.31.0
|
||||
google-auth==2.32.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
@@ -135,7 +130,7 @@ google-auth==2.31.0
|
||||
# google-generativeai
|
||||
google-auth-httplib2==0.2.0
|
||||
# via google-api-python-client
|
||||
google-generativeai==0.7.1
|
||||
google-generativeai==0.7.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
googleapis-common-protos==1.63.2
|
||||
# via
|
||||
@@ -143,7 +138,7 @@ googleapis-common-protos==1.63.2
|
||||
# grpcio-status
|
||||
greenlet==3.0.3
|
||||
# via sqlalchemy
|
||||
grpcio==1.64.1
|
||||
grpcio==1.65.1
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
@@ -165,7 +160,6 @@ httptools==0.6.1
|
||||
httpx==0.27.0
|
||||
# via
|
||||
# anthropic
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# fal-client
|
||||
# fastapi
|
||||
@@ -173,7 +167,7 @@ httpx==0.27.0
|
||||
# openpipe
|
||||
httpx-sse==0.4.0
|
||||
# via fal-client
|
||||
huggingface-hub==0.23.4
|
||||
huggingface-hub==0.24.1
|
||||
# via
|
||||
# faster-whisper
|
||||
# timm
|
||||
@@ -188,8 +182,6 @@ idna==3.7
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jinja2==3.1.4
|
||||
@@ -203,23 +195,23 @@ jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
langchain==0.2.6
|
||||
langchain==0.2.11
|
||||
# via
|
||||
# langchain-community
|
||||
# pipecat-ai (pyproject.toml)
|
||||
langchain-community==0.2.6
|
||||
langchain-community==0.2.10
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-core==0.2.10
|
||||
langchain-core==0.2.23
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-openai
|
||||
# langchain-text-splitters
|
||||
langchain-openai==0.1.10
|
||||
langchain-openai==0.1.17
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.83
|
||||
langsmith==0.1.93
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -296,31 +288,26 @@ nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
onnxruntime==1.18.1
|
||||
# via faster-whisper
|
||||
openai==1.27.0
|
||||
openai==1.35.15
|
||||
# via
|
||||
# langchain-openai
|
||||
# openpipe
|
||||
# pipecat-ai (pyproject.toml)
|
||||
openpipe==4.16.0
|
||||
openpipe==4.18.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
orjson==3.10.5
|
||||
# via
|
||||
# fastapi
|
||||
# langsmith
|
||||
orjson==3.10.6
|
||||
# via langsmith
|
||||
packaging==24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# langchain-core
|
||||
# marshmallow
|
||||
# onnxruntime
|
||||
# pytest
|
||||
# transformers
|
||||
pillow==10.3.0
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# torchvision
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
proto-plus==1.24.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
@@ -344,7 +331,7 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pyaudio==0.2.14
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pydantic==2.8.0
|
||||
pydantic==2.8.2
|
||||
# via
|
||||
# anthropic
|
||||
# fastapi
|
||||
@@ -353,7 +340,7 @@ pydantic==2.8.0
|
||||
# langchain-core
|
||||
# langsmith
|
||||
# openai
|
||||
pydantic-core==2.20.0
|
||||
pydantic-core==2.20.1
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
@@ -363,10 +350,6 @@ pyloudnorm==0.1.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
pytest==8.2.2
|
||||
# via pytest-asyncio
|
||||
pytest-asyncio==0.23.7
|
||||
# via cartesia
|
||||
python-dateutil==2.9.0.post0
|
||||
# via openpipe
|
||||
python-dotenv==1.0.1
|
||||
@@ -391,7 +374,6 @@ regex==2024.5.15
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# cartesia
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# langchain
|
||||
@@ -428,11 +410,11 @@ sqlalchemy==2.0.31
|
||||
# langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
sympy==1.12.1
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tenacity==8.4.2
|
||||
tenacity==8.5.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -446,8 +428,6 @@ tokenizers==0.19.1
|
||||
# anthropic
|
||||
# faster-whisper
|
||||
# transformers
|
||||
tomli==2.0.1
|
||||
# via pytest
|
||||
torch==2.3.1
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
@@ -489,13 +469,11 @@ typing-extensions==4.12.2
|
||||
# uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
ujson==5.10.0
|
||||
# via fastapi
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==2.2.2
|
||||
# via requests
|
||||
uvicorn[standard]==0.30.1
|
||||
uvicorn[standard]==0.30.3
|
||||
# via fastapi
|
||||
uvloop==0.19.0
|
||||
# via uvicorn
|
||||
@@ -505,7 +483,6 @@ watchfiles==0.22.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# uvicorn
|
||||
|
||||
@@ -8,7 +8,6 @@ aiofiles==24.1.0
|
||||
# via deepgram-sdk
|
||||
aiohttp==3.9.5
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -36,17 +35,15 @@ attrs==23.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# openpipe
|
||||
av==12.2.0
|
||||
av==12.3.0
|
||||
# via faster-whisper
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
blinker==1.8.2
|
||||
# via flask
|
||||
cachetools==5.3.3
|
||||
cachetools==5.4.0
|
||||
# via google-auth
|
||||
cartesia==1.0.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
certifi==2024.6.2
|
||||
certifi==2024.7.4
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
@@ -80,13 +77,11 @@ einops==0.8.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
email-validator==2.2.0
|
||||
# via fastapi
|
||||
exceptiongroup==1.2.1
|
||||
# via
|
||||
# anyio
|
||||
# pytest
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fal-client==0.4.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
fastapi==0.111.0
|
||||
fastapi==0.111.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
@@ -123,9 +118,9 @@ google-api-core[grpc]==2.19.1
|
||||
# google-ai-generativelanguage
|
||||
# google-api-python-client
|
||||
# google-generativeai
|
||||
google-api-python-client==2.135.0
|
||||
google-api-python-client==2.137.0
|
||||
# via google-generativeai
|
||||
google-auth==2.31.0
|
||||
google-auth==2.32.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
@@ -134,13 +129,13 @@ google-auth==2.31.0
|
||||
# google-generativeai
|
||||
google-auth-httplib2==0.2.0
|
||||
# via google-api-python-client
|
||||
google-generativeai==0.7.1
|
||||
google-generativeai==0.7.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
googleapis-common-protos==1.63.2
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
grpcio==1.64.1
|
||||
grpcio==1.65.1
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
@@ -162,7 +157,6 @@ httptools==0.6.1
|
||||
httpx==0.27.0
|
||||
# via
|
||||
# anthropic
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# fal-client
|
||||
# fastapi
|
||||
@@ -170,7 +164,7 @@ httpx==0.27.0
|
||||
# openpipe
|
||||
httpx-sse==0.4.0
|
||||
# via fal-client
|
||||
huggingface-hub==0.23.4
|
||||
huggingface-hub==0.24.1
|
||||
# via
|
||||
# faster-whisper
|
||||
# timm
|
||||
@@ -185,8 +179,6 @@ idna==3.7
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jinja2==3.1.4
|
||||
@@ -200,23 +192,23 @@ jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
langchain==0.2.6
|
||||
langchain==0.2.11
|
||||
# via
|
||||
# langchain-community
|
||||
# pipecat-ai (pyproject.toml)
|
||||
langchain-community==0.2.6
|
||||
langchain-community==0.2.10
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-core==0.2.10
|
||||
langchain-core==0.2.23
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-openai
|
||||
# langchain-text-splitters
|
||||
langchain-openai==0.1.10
|
||||
langchain-openai==0.1.17
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.83
|
||||
langsmith==0.1.93
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -262,31 +254,26 @@ numpy==1.26.4
|
||||
# transformers
|
||||
onnxruntime==1.18.1
|
||||
# via faster-whisper
|
||||
openai==1.27.0
|
||||
openai==1.35.15
|
||||
# via
|
||||
# langchain-openai
|
||||
# openpipe
|
||||
# pipecat-ai (pyproject.toml)
|
||||
openpipe==4.16.0
|
||||
openpipe==4.18.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
orjson==3.10.5
|
||||
# via
|
||||
# fastapi
|
||||
# langsmith
|
||||
orjson==3.10.6
|
||||
# via langsmith
|
||||
packaging==24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# langchain-core
|
||||
# marshmallow
|
||||
# onnxruntime
|
||||
# pytest
|
||||
# transformers
|
||||
pillow==10.3.0
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# torchvision
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
proto-plus==1.24.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
@@ -310,7 +297,7 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pyaudio==0.2.14
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pydantic==2.8.0
|
||||
pydantic==2.8.2
|
||||
# via
|
||||
# anthropic
|
||||
# fastapi
|
||||
@@ -319,7 +306,7 @@ pydantic==2.8.0
|
||||
# langchain-core
|
||||
# langsmith
|
||||
# openai
|
||||
pydantic-core==2.20.0
|
||||
pydantic-core==2.20.1
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
@@ -329,10 +316,6 @@ pyloudnorm==0.1.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
pytest==8.2.2
|
||||
# via pytest-asyncio
|
||||
pytest-asyncio==0.23.7
|
||||
# via cartesia
|
||||
python-dateutil==2.9.0.post0
|
||||
# via openpipe
|
||||
python-dotenv==1.0.1
|
||||
@@ -357,7 +340,6 @@ regex==2024.5.15
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# cartesia
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# langchain
|
||||
@@ -394,11 +376,11 @@ sqlalchemy==2.0.31
|
||||
# langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
sympy==1.12.1
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tenacity==8.4.2
|
||||
tenacity==8.5.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -412,8 +394,6 @@ tokenizers==0.19.1
|
||||
# anthropic
|
||||
# faster-whisper
|
||||
# transformers
|
||||
tomli==2.0.1
|
||||
# via pytest
|
||||
torch==2.3.1
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
@@ -453,13 +433,11 @@ typing-extensions==4.12.2
|
||||
# uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
ujson==5.10.0
|
||||
# via fastapi
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==2.2.2
|
||||
# via requests
|
||||
uvicorn[standard]==0.30.1
|
||||
uvicorn[standard]==0.30.3
|
||||
# via fastapi
|
||||
uvloop==0.19.0
|
||||
# via uvicorn
|
||||
@@ -469,7 +447,6 @@ watchfiles==0.22.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# uvicorn
|
||||
|
||||
@@ -43,12 +43,12 @@ examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
fal = [ "fal-client~=0.4.1" ]
|
||||
gladia = [ "websockets~=12.0" ]
|
||||
google = [ "google-generativeai~=0.7.1" ]
|
||||
fireworks = [ "openai~=1.27.0" ]
|
||||
langchain = [ "langchain~=0.2.6", "langchain-community~=0.2.6", "langchain-openai~=0.1.10" ]
|
||||
fireworks = [ "openai~=1.35.0" ]
|
||||
langchain = [ "langchain~=0.2.10", "langchain-community~=0.2.9", "langchain-openai~=0.1.17" ]
|
||||
local = [ "pyaudio~=0.2.0" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
|
||||
openai = [ "openai~=1.27.0" ]
|
||||
openpipe = [ "openpipe~=4.16.0" ]
|
||||
openai = [ "openai~=1.35.0" ]
|
||||
openpipe = [ "openpipe~=4.18.0" ]
|
||||
playht = [ "pyht~=0.0.28" ]
|
||||
silero = [ "torch~=2.3.1", "torchaudio~=2.3.1" ]
|
||||
websocket = [ "websockets~=12.0", "fastapi~=0.111.0" ]
|
||||
|
||||
@@ -212,14 +212,6 @@ class SystemFrame(Frame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartFrame(SystemFrame):
|
||||
"""This is the first frame that should be pushed down a pipeline."""
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
report_only_initial_ttfb: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CancelFrame(SystemFrame):
|
||||
"""Indicates that a pipeline needs to stop right away."""
|
||||
@@ -278,17 +270,6 @@ class BotInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotSpeakingFrame(SystemFrame):
|
||||
"""Emitted by transport outputs while the bot is still speaking. This can be
|
||||
used, for example, to detect when a user is idle. That is, while the bot is
|
||||
speaking we don't want to trigger any user idle timeout since the user might
|
||||
be listening.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricsFrame(SystemFrame):
|
||||
"""Emitted by processor that can compute metrics like latencies.
|
||||
@@ -306,6 +287,14 @@ class ControlFrame(Frame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartFrame(ControlFrame):
|
||||
"""This is the first frame that should be pushed down a pipeline."""
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
report_only_initial_ttfb: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndFrame(ControlFrame):
|
||||
"""Indicates that a pipeline has ended and frame processors and pipelines
|
||||
@@ -348,6 +337,33 @@ class UserStoppedSpeakingFrame(ControlFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStartedSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot started speaking.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStoppedSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot stopped speaking.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs while the bot is still
|
||||
speaking. This can be used, for example, to detect when a user is idle. That
|
||||
is, while the bot is speaking we don't want to trigger any user idle timeout
|
||||
since the user might be listening.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSStartedFrame(ControlFrame):
|
||||
"""Used to indicate the beginning of a TTS response. Following
|
||||
|
||||
@@ -7,11 +7,13 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
|
||||
from typing import List, Literal, Optional, Type
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type
|
||||
from pydantic import PrivateAttr, BaseModel, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -33,62 +35,76 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.ai_services import AIService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
|
||||
DEFAULT_MESSAGES = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
}
|
||||
]
|
||||
|
||||
DEFAULT_MODEL = "llama3-70b-8192"
|
||||
|
||||
DEFAULT_VOICE = "79a125e8-cd45-4c13-8a67-188112f4dd22"
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RTVILLMConfig(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: Optional[List[dict]] = None
|
||||
class RTVIServiceOption(BaseModel):
|
||||
name: str
|
||||
handler: Optional[Callable[['RTVIProcessor',
|
||||
'RTVIServiceOptionConfig'],
|
||||
Awaitable[None]]] = None
|
||||
|
||||
|
||||
class RTVITTSConfig(BaseModel):
|
||||
voice: Optional[str] = None
|
||||
class RTVIService(BaseModel):
|
||||
name: str
|
||||
cls: Type[FrameProcessor]
|
||||
options: List[RTVIServiceOption]
|
||||
_options_dict: Dict[str, RTVIServiceOption] = PrivateAttr(default={})
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._options_dict = {}
|
||||
for option in self.options:
|
||||
self._options_dict[option.name] = option
|
||||
return super().model_post_init(__context)
|
||||
|
||||
#
|
||||
# Client -> Pipecat messages.
|
||||
#
|
||||
|
||||
|
||||
class RTVIServiceOptionConfig(BaseModel):
|
||||
name: str
|
||||
value: Any
|
||||
|
||||
|
||||
class RTVIServiceConfig(BaseModel):
|
||||
service: str
|
||||
options: List[RTVIServiceOptionConfig]
|
||||
|
||||
|
||||
class RTVIConfig(BaseModel):
|
||||
llm: Optional[RTVILLMConfig] = None
|
||||
tts: Optional[RTVITTSConfig] = None
|
||||
config: List[RTVIServiceConfig]
|
||||
_config_dict: Dict[str, RTVIServiceConfig] = PrivateAttr(default={})
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._config_dict = {}
|
||||
for c in self.config:
|
||||
self._config_dict[c.service] = c
|
||||
return super().model_post_init(__context)
|
||||
|
||||
|
||||
class RTVISetup(BaseModel):
|
||||
config: Optional[RTVIConfig] = None
|
||||
|
||||
|
||||
class RTVILLMMessageData(BaseModel):
|
||||
class RTVILLMContextData(BaseModel):
|
||||
messages: List[dict]
|
||||
|
||||
|
||||
class RTVITTSMessageData(BaseModel):
|
||||
class RTVITTSSpeakData(BaseModel):
|
||||
text: str
|
||||
interrupt: Optional[bool] = False
|
||||
|
||||
|
||||
class RTVIMessageData(BaseModel):
|
||||
setup: Optional[RTVISetup] = None
|
||||
config: Optional[RTVIConfig] = None
|
||||
llm: Optional[RTVILLMMessageData] = None
|
||||
tts: Optional[RTVITTSMessageData] = None
|
||||
|
||||
|
||||
class RTVIMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: str
|
||||
id: str
|
||||
data: Optional[RTVIMessageData] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
#
|
||||
# Pipecat -> Client responses and messages.
|
||||
#
|
||||
|
||||
|
||||
class RTVIResponseData(BaseModel):
|
||||
@@ -97,7 +113,7 @@ class RTVIResponseData(BaseModel):
|
||||
|
||||
|
||||
class RTVIResponse(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["response"] = "response"
|
||||
id: str
|
||||
data: RTVIResponseData
|
||||
@@ -108,7 +124,7 @@ class RTVIErrorData(BaseModel):
|
||||
|
||||
|
||||
class RTVIError(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["error"] = "error"
|
||||
data: RTVIErrorData
|
||||
|
||||
@@ -118,7 +134,7 @@ class RTVILLMContextMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVILLMContextMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["llm-context"] = "llm-context"
|
||||
data: RTVILLMContextMessageData
|
||||
|
||||
@@ -128,13 +144,13 @@ class RTVITTSTextMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVITTSTextMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["tts-text"] = "tts-text"
|
||||
data: RTVITTSTextMessageData
|
||||
|
||||
|
||||
class RTVIBotReady(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-ready"] = "bot-ready"
|
||||
|
||||
|
||||
@@ -146,23 +162,23 @@ class RTVITranscriptionMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVITranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-transcription"] = "user-transcription"
|
||||
data: RTVITranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-started-speaking"] = "user-started-speaking"
|
||||
|
||||
|
||||
class RTVIUserStoppedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-stopped-speaking"] = "user-stopped-speaking"
|
||||
|
||||
|
||||
class RTVIJSONCompletion(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["json-completion"] = "json-completion"
|
||||
data: str
|
||||
|
||||
@@ -265,59 +281,128 @@ class RTVITTSTextProcessor(FrameProcessor):
|
||||
await self.push_frame(TransportMessageFrame(message=message.model_dump(exclude_none=True)))
|
||||
|
||||
|
||||
async def handle_llm_model_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
|
||||
frame = LLMModelUpdateFrame(option.value)
|
||||
await rtvi.push_frame(frame)
|
||||
|
||||
|
||||
async def handle_llm_messages_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
|
||||
frame = LLMMessagesUpdateFrame(option.value)
|
||||
await rtvi.push_frame(frame)
|
||||
|
||||
|
||||
async def handle_tts_voice_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
|
||||
frame = TTSVoiceUpdateFrame(option.value)
|
||||
await rtvi.push_frame(frame)
|
||||
|
||||
DEFAULT_LLM_SERVICE = RTVIService(
|
||||
name="llm",
|
||||
cls=OpenAILLMService,
|
||||
options=[
|
||||
RTVIServiceOption(name="model", handler=handle_llm_model_update),
|
||||
RTVIServiceOption(name="messages", handler=handle_llm_messages_update)
|
||||
])
|
||||
|
||||
DEFAULT_TTS_SERVICE = RTVIService(
|
||||
name="tts",
|
||||
cls=CartesiaTTSService,
|
||||
options=[
|
||||
RTVIServiceOption(name="voice_id", handler=handle_tts_voice_update),
|
||||
])
|
||||
|
||||
|
||||
class RTVIProcessor(FrameProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transport: BaseTransport,
|
||||
setup: RTVISetup | None = None,
|
||||
llm_api_key: str = "",
|
||||
llm_base_url: str = "https://api.groq.com/openai/v1",
|
||||
tts_api_key: str = "",
|
||||
llm_cls: Type[AIService] = OpenAILLMService,
|
||||
tts_cls: Type[AIService] = CartesiaTTSService):
|
||||
def __init__(self, *, transport: BaseTransport):
|
||||
super().__init__()
|
||||
self._transport = transport
|
||||
self._setup = setup
|
||||
self._llm_api_key = llm_api_key
|
||||
self._llm_base_url = llm_base_url
|
||||
self._tts_api_key = tts_api_key
|
||||
self._llm_cls = llm_cls
|
||||
self._tts_cls = tts_cls
|
||||
self._config: RTVIConfig | None = None
|
||||
self._ctor_args: Dict[str, Any] = {}
|
||||
|
||||
self._start_frame: Frame | None = None
|
||||
self._llm: FrameProcessor | None = None
|
||||
self._tts: FrameProcessor | None = None
|
||||
self._pipeline: FrameProcessor | None = None
|
||||
self._first_participant_joined: bool = False
|
||||
|
||||
# Register transport event so we can send a `bot-ready` event (and maybe
|
||||
# others) when the participant joins.
|
||||
transport.add_event_handler(
|
||||
"on_first_participant_joined",
|
||||
self._on_first_participant_joined)
|
||||
|
||||
# Register default services.
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
self.register_service(DEFAULT_LLM_SERVICE)
|
||||
self.register_service(DEFAULT_TTS_SERVICE)
|
||||
|
||||
self._frame_handler_task = self.get_event_loop().create_task(self._frame_handler())
|
||||
self._frame_queue = asyncio.Queue()
|
||||
|
||||
def register_service(self, service: RTVIService):
|
||||
self._registered_services[service.name] = service
|
||||
|
||||
def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]):
|
||||
self._config = config
|
||||
self._ctor_args = ctor_args
|
||||
|
||||
async def update_config(self, config: RTVIConfig):
|
||||
if self._pipeline:
|
||||
await self._handle_config_update(config)
|
||||
self._config = config
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
# Specific system frames
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
await self._internal_push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._frame_queue.put((frame, direction))
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
self._start_frame = frame
|
||||
try:
|
||||
await self._handle_setup(self._setup)
|
||||
except Exception as e:
|
||||
await self._send_error(f"unable to setup RTVI: {e}")
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
async def cleanup(self):
|
||||
if self._pipeline:
|
||||
await self._pipeline.cleanup()
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
try:
|
||||
await self._handle_pipeline_setup(frame, self._config)
|
||||
except Exception as e:
|
||||
await self._send_error(f"unable to setup RTVI pipeline: {e}")
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._frame_handler_task
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
self._frame_handler_task.cancel()
|
||||
await self._frame_handler_task
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._frame_queue.put((frame, direction))
|
||||
|
||||
async def _frame_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._frame_queue.get()
|
||||
await self._handle_frame(frame, direction)
|
||||
self._frame_queue.task_done()
|
||||
running = not isinstance(frame, EndFrame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -372,113 +457,102 @@ class RTVIProcessor(FrameProcessor):
|
||||
try:
|
||||
message = RTVIMessage.model_validate(frame.message)
|
||||
except ValidationError as e:
|
||||
await self._send_error(f"invalid message: {e}")
|
||||
await self._send_error(f"Invalid incoming message: {e}")
|
||||
logger.warning(f"Invalid incoming message: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
success = True
|
||||
error = None
|
||||
match message.type:
|
||||
case "setup":
|
||||
setup = None
|
||||
if message.data:
|
||||
setup = message.data.setup
|
||||
await self._handle_setup(message.id, setup)
|
||||
case "config-update":
|
||||
await self._handle_config_update(message.data.config)
|
||||
await self._handle_config_update(RTVIConfig.model_validate(message.data))
|
||||
case "llm-get-context":
|
||||
await self._handle_llm_get_context()
|
||||
case "llm-append-context":
|
||||
await self._handle_llm_append_context(message.data.llm)
|
||||
await self._handle_llm_append_context(RTVILLMContextData.model_validate(message.data))
|
||||
case "llm-update-context":
|
||||
await self._handle_llm_update_context(message.data.llm)
|
||||
await self._handle_llm_update_context(RTVILLMContextData.model_validate(message.data))
|
||||
case "tts-speak":
|
||||
await self._handle_tts_speak(message.data.tts)
|
||||
await self._handle_tts_speak(RTVITTSSpeakData.model_validate(message.data))
|
||||
case "tts-interrupt":
|
||||
await self._handle_tts_interrupt()
|
||||
case _:
|
||||
success = False
|
||||
error = f"unsupported type {message.type}"
|
||||
error = f"Unsupported type {message.type}"
|
||||
|
||||
await self._send_response(message.id, success, error)
|
||||
except ValidationError as e:
|
||||
await self._send_response(message.id, False, f"invalid message: {e}")
|
||||
await self._send_response(message.id, False, f"Invalid incoming message: {e}")
|
||||
logger.warning(f"Invalid incoming message: {e}")
|
||||
except Exception as e:
|
||||
await self._send_response(message.id, False, f"{e}")
|
||||
await self._send_response(message.id, False, f"Exception processing message: {e}")
|
||||
logger.warning(f"Exception processing message: {e}")
|
||||
|
||||
async def _handle_setup(self, setup: RTVISetup | None):
|
||||
model = DEFAULT_MODEL
|
||||
if setup and setup.config and setup.config.llm and setup.config.llm.model:
|
||||
model = setup.config.llm.model
|
||||
async def _handle_pipeline_setup(self, start_frame: StartFrame, config: RTVIConfig | None):
|
||||
# TODO(aleix): We shouldn't need to save this in `self._tma_in`.
|
||||
self._tma_in = LLMUserResponseAggregator()
|
||||
tma_out = LLMAssistantResponseAggregator()
|
||||
|
||||
messages = DEFAULT_MESSAGES
|
||||
if setup and setup.config and setup.config.llm and setup.config.llm.messages:
|
||||
messages = setup.config.llm.messages
|
||||
llm_cls = self._registered_services["llm"].cls
|
||||
llm_args = self._ctor_args["llm"]
|
||||
llm = llm_cls(**llm_args)
|
||||
|
||||
voice = DEFAULT_VOICE
|
||||
if setup and setup.config and setup.config.tts and setup.config.tts.voice:
|
||||
voice = setup.config.tts.voice
|
||||
|
||||
self._tma_in = LLMUserResponseAggregator(messages)
|
||||
self._tma_out = LLMAssistantResponseAggregator(messages)
|
||||
|
||||
self._llm = self._llm_cls(
|
||||
name="LLM",
|
||||
base_url=self._llm_base_url,
|
||||
api_key=self._llm_api_key,
|
||||
model=model)
|
||||
|
||||
self._tts = self._tts_cls(name="TTS", api_key=self._tts_api_key, voice_id=voice)
|
||||
tts_cls = self._registered_services["tts"].cls
|
||||
tts_args = self._ctor_args["tts"]
|
||||
tts = tts_cls(**tts_args)
|
||||
|
||||
# TODO-CB: Eventually we'll need to switch the context aggregators to use the
|
||||
# OpenAI context frames instead of message frames
|
||||
context = OpenAILLMContext(messages=messages)
|
||||
self._fc = FunctionCaller(context)
|
||||
context = OpenAILLMContext()
|
||||
fc = FunctionCaller(context)
|
||||
|
||||
self._tts_text = RTVITTSTextProcessor()
|
||||
tts_text = RTVITTSTextProcessor()
|
||||
|
||||
pipeline = Pipeline([
|
||||
self._tma_in,
|
||||
self._llm,
|
||||
self._fc,
|
||||
self._tts,
|
||||
self._tts_text,
|
||||
self._tma_out,
|
||||
llm,
|
||||
fc,
|
||||
tts,
|
||||
tts_text,
|
||||
tma_out,
|
||||
self._transport.output(),
|
||||
])
|
||||
self._pipeline = pipeline
|
||||
|
||||
parent = self.get_parent()
|
||||
if parent and self._start_frame:
|
||||
if parent:
|
||||
parent.link(pipeline)
|
||||
|
||||
# We need to initialize the new pipeline with the same settings
|
||||
# as the initial one.
|
||||
start_frame = dataclasses.replace(self._start_frame)
|
||||
start_frame = dataclasses.replace(start_frame)
|
||||
await self.push_frame(start_frame)
|
||||
|
||||
# Configure the pipeline
|
||||
if config:
|
||||
await self._handle_config_update(config)
|
||||
|
||||
# Send new initial metrics with the new processors
|
||||
processors = parent.processors_with_metrics()
|
||||
processors.extend(self._pipeline.processors_with_metrics())
|
||||
processors.extend(pipeline.processors_with_metrics())
|
||||
ttfb = [{"processor": p.name, "value": 0.0} for p in processors]
|
||||
processing = [{"processor": p.name, "value": 0.0} for p in processors]
|
||||
await self.push_frame(MetricsFrame(ttfb=ttfb, processing=processing))
|
||||
|
||||
message = RTVIBotReady()
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
self._pipeline = pipeline
|
||||
|
||||
async def _handle_config_update(self, config: RTVIConfig):
|
||||
# Change voice before LLM updates, so we can hear the new vocie.
|
||||
if config.tts and config.tts.voice:
|
||||
frame = TTSVoiceUpdateFrame(config.tts.voice)
|
||||
await self.push_frame(frame)
|
||||
if config.llm and config.llm.model:
|
||||
frame = LLMModelUpdateFrame(config.llm.model)
|
||||
await self.push_frame(frame)
|
||||
if config.llm and config.llm.messages:
|
||||
frame = LLMMessagesUpdateFrame(config.llm.messages)
|
||||
await self.push_frame(frame)
|
||||
await self._maybe_send_bot_ready()
|
||||
|
||||
async def _handle_config_service(self, config: RTVIServiceConfig):
|
||||
service = self._registered_services[config.service]
|
||||
for option in config.options:
|
||||
handler = service._options_dict[option.name].handler
|
||||
if handler:
|
||||
await handler(self, option)
|
||||
|
||||
async def _handle_config_update(self, data: RTVIConfig):
|
||||
for config in data.config:
|
||||
await self._handle_config_service(config)
|
||||
|
||||
async def _handle_llm_get_context(self):
|
||||
data = RTVILLMContextMessageData(messages=self._tma_in.messages)
|
||||
@@ -486,17 +560,17 @@ class RTVIProcessor(FrameProcessor):
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_llm_append_context(self, data: RTVILLMMessageData):
|
||||
async def _handle_llm_append_context(self, data: RTVILLMContextData):
|
||||
if data and data.messages:
|
||||
frame = LLMMessagesAppendFrame(data.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_llm_update_context(self, data: RTVILLMMessageData):
|
||||
async def _handle_llm_update_context(self, data: RTVILLMContextData):
|
||||
if data and data.messages:
|
||||
frame = LLMMessagesUpdateFrame(data.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_tts_speak(self, data: RTVITTSMessageData):
|
||||
async def _handle_tts_speak(self, data: RTVITTSSpeakData):
|
||||
if data and data.text:
|
||||
if data.interrupt:
|
||||
await self._handle_tts_interrupt()
|
||||
@@ -506,6 +580,16 @@ class RTVIProcessor(FrameProcessor):
|
||||
async def _handle_tts_interrupt(self):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def _on_first_participant_joined(self, transport, participant):
|
||||
self._first_participant_joined = True
|
||||
await self._maybe_send_bot_ready()
|
||||
|
||||
async def _maybe_send_bot_ready(self):
|
||||
if self._pipeline and self._first_participant_joined:
|
||||
message = RTVIBotReady()
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _send_error(self, error: str):
|
||||
message = RTVIError(data=RTVIErrorData(message=error))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
@@ -523,7 +607,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._pipeline = pipeline
|
||||
|
||||
parent = self.get_parent()
|
||||
if parent and self._start_frame:
|
||||
if parent:
|
||||
parent.link(pipeline)
|
||||
|
||||
message = RTVIResponse(id=id, data=RTVIResponseData(success=success, error=error))
|
||||
|
||||
@@ -8,7 +8,12 @@ import asyncio
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from pipecat.frames.frames import BotSpeakingFrame, Frame, StartInterruptionFrame, StopInterruptionFrame, SystemFrame
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
Frame,
|
||||
SystemFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
@@ -47,10 +52,10 @@ class UserIdleProcessor(AsyncFrameProcessor):
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
# We shouldn't call the idle callback if the user or the bot are speaking.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
self._interrupted = True
|
||||
self._idle_event.set()
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._interrupted = False
|
||||
self._idle_event.set()
|
||||
elif isinstance(frame, BotSpeakingFrame):
|
||||
|
||||
@@ -283,14 +283,17 @@ class STTService(AIService):
|
||||
await self.stop_processing_metrics()
|
||||
(self._content, self._wave) = self._new_wave()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
self._wave.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
self._wave.close()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
self._wave.close()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
# In this service we accumulate audio internally and at the end we
|
||||
# push a TextFrame. We don't really want to push audio frames down.
|
||||
await self._append_audio(frame)
|
||||
|
||||
@@ -147,13 +147,16 @@ class AzureSTTService(AsyncAIService):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._audio_stream.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._audio_stream.close()
|
||||
|
||||
@@ -168,12 +171,12 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
image_size: str,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_version="2023-06-01-preview",
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -181,8 +184,14 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
self._azure_endpoint = endpoint
|
||||
self._api_version = api_version
|
||||
self._model = model
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._image_size = image_size
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import AsyncGenerator
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
Frame,
|
||||
AudioRawFrame,
|
||||
StartInterruptionFrame,
|
||||
@@ -98,6 +99,10 @@ class CartesiaTTSService(TTSService):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
self._websocket = await websockets.connect(
|
||||
@@ -111,6 +116,8 @@ class CartesiaTTSService(TTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._context_appending_task:
|
||||
self._context_appending_task.cancel()
|
||||
await self._context_appending_task
|
||||
@@ -120,13 +127,12 @@ class CartesiaTTSService(TTSService):
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
if self._websocket:
|
||||
ws = self._websocket
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
await ws.close()
|
||||
|
||||
self._context_id = None
|
||||
self._context_id_start_timestamp = None
|
||||
self._timestamped_words_buffer = []
|
||||
await self.stop_all_metrics()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error closing websocket: {e}")
|
||||
|
||||
@@ -142,13 +148,13 @@ class CartesiaTTSService(TTSService):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
msg = json.loads(message)
|
||||
# logger.debug(f"Received message: {msg['type']} {msg['context_id']}")
|
||||
if not msg or msg["context_id"] != self._context_id:
|
||||
continue
|
||||
if msg["type"] == "done":
|
||||
await self.stop_ttfb_metrics()
|
||||
# unset _context_id but not the _context_id_start_timestamp because we are likely still
|
||||
# playing out audio and need the timestamp to set send context frames
|
||||
# Unset _context_id but not the _context_id_start_timestamp
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
|
||||
elif msg["type"] == "timestamps":
|
||||
@@ -166,6 +172,8 @@ class CartesiaTTSService(TTSService):
|
||||
num_channels=1
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
@@ -176,15 +184,17 @@ class CartesiaTTSService(TTSService):
|
||||
if not self._context_id_start_timestamp:
|
||||
continue
|
||||
elapsed_seconds = time.time() - self._context_id_start_timestamp
|
||||
# pop all words from self._timestamped_words_buffer that are older than the
|
||||
# elapsed time and print a message about them to the console
|
||||
# Pop all words from self._timestamped_words_buffer that are
|
||||
# older than the elapsed time and print a message about them to
|
||||
# the console.
|
||||
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
|
||||
word, timestamp = self._timestamped_words_buffer.pop(0)
|
||||
if word == "LLMFullResponseEndFrame" and timestamp == 0:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
continue
|
||||
# print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.")
|
||||
await self.push_frame(TextFrame(word))
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
@@ -212,7 +222,6 @@ class CartesiaTTSService(TTSService):
|
||||
"language": self._language,
|
||||
"add_timestamps": True,
|
||||
}
|
||||
# logger.debug(f"SENDING MESSAGE {json.dumps(msg)}")
|
||||
try:
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except Exception as e:
|
||||
|
||||
@@ -45,21 +45,31 @@ class DeepgramTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
voice: str = "aura-helios-en",
|
||||
base_url: str = "https://api.deepgram.com/v1/speak",
|
||||
sample_rate: int = 16000,
|
||||
encoding: str = "linear16",
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._api_key = api_key
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._base_url = base_url
|
||||
self._sample_rate = sample_rate
|
||||
self._encoding = encoding
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
@@ -68,7 +78,7 @@ class DeepgramTTSService(TTSService):
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
base_url = self._base_url
|
||||
request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000"
|
||||
request_url = f"{base_url}?model={self._voice}&encoding={self._encoding}&container=none&sample_rate={self._sample_rate}"
|
||||
headers = {"authorization": f"token {self._api_key}"}
|
||||
body = {"text": text}
|
||||
|
||||
@@ -91,7 +101,7 @@ class DeepgramTTSService(TTSService):
|
||||
|
||||
async for data in r.content:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = AudioRawFrame(audio=data, sample_rate=16000, num_channels=1)
|
||||
frame = AudioRawFrame(audio=data, sample_rate=self._sample_rate, num_channels=1)
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
@@ -132,15 +142,18 @@ class DeepgramSTTService(AsyncAIService):
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if await self._connection.start(self._live_options):
|
||||
logger.debug(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._connection.finish()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._connection.finish()
|
||||
|
||||
async def _on_message(self, *args, **kwargs):
|
||||
|
||||
@@ -19,21 +19,27 @@ class ElevenLabsTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model: str = "eleven_turbo_v2",
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._model = model
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
@@ -39,18 +39,24 @@ class FalImageGenService(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
params: InputParams,
|
||||
model: str = "fal-ai/fast-sdxl",
|
||||
key: str | None = None,
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._params = params
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
if key:
|
||||
os.environ["FAL_KEY"] = key
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
|
||||
@@ -68,14 +68,17 @@ class GladiaSTTService(AsyncAIService):
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._websocket = await websockets.connect(self._url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
await self._setup_gladia()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
async def _setup_gladia(self):
|
||||
|
||||
@@ -253,16 +253,22 @@ class OpenAIImageGenService(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
model: str = "dall-e-3",
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._image_size = image_size
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
@@ -38,22 +38,28 @@ class XTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
voice_id: str,
|
||||
language: str,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice_id = voice_id
|
||||
self._language = language
|
||||
self._base_url = base_url
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json()
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
@@ -18,6 +18,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
@@ -45,12 +46,26 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_in_queue = asyncio.Queue()
|
||||
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
# Wait for the task to finish.
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Cancel and wait for the audio input task to finish.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
# Wait for the push frame task to finish. It will finish when the
|
||||
# EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
def vad_analyzer(self) -> VADAnalyzer | None:
|
||||
return self._params.vad_analyzer
|
||||
|
||||
@@ -62,25 +77,32 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Specific system frames
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
# We don't queue a CancelFrame since we want to stop ASAP.
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_interruptions(frame, False)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._start_interruption()
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
await self._stop_interruption()
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self._internal_push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.stop()
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_interruptions(frame, False)
|
||||
await self.stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
@@ -100,10 +122,12 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
@@ -113,6 +137,9 @@ class BaseInputTransport(FrameProcessor):
|
||||
#
|
||||
|
||||
async def _start_interruption(self):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
@@ -124,6 +151,9 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame, push_frame: bool):
|
||||
|
||||
@@ -15,6 +15,8 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
MetricsFrame,
|
||||
SpriteFrame,
|
||||
@@ -25,6 +27,8 @@ from pipecat.frames.frames import (
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TransportMessageFrame)
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
@@ -60,18 +64,34 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
# Create media threads queues.
|
||||
# Create camera output queue and task if needed.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_queue = asyncio.Queue()
|
||||
self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
# Wait on the threads to finish.
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Cancel and wait for the camera output task to finish.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
self._stopped_event.set()
|
||||
# Wait for the push frame and sink tasks to finish. They will finish when
|
||||
# the EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
await self._sink_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
pass
|
||||
@@ -89,48 +109,38 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
if self._sink_task:
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
#
|
||||
# Out-of-band frames like (CancelFrame or StartInterruptionFrame) are
|
||||
# pushed immediately. Other frames require order so they are put in the
|
||||
# sink queue.
|
||||
# System frames (like StartInterruptionFrame) are pushed
|
||||
# immediately. Other frames require order so they are put in the sink
|
||||
# queue.
|
||||
#
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# EndFrame is managed in the sink queue handler.
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.cancel(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
await self.send_metrics(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.send_metrics(frame)
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames.
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self._sink_queue.put(frame)
|
||||
await self.start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._sink_queue.put(frame)
|
||||
await self.stop(frame)
|
||||
# Other frames.
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
else:
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
# If we are finishing, wait here until we have stopped, otherwise we might
|
||||
# close things too early upstream. We need this event because we don't
|
||||
# know when the internal threads will finish.
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
await self._stopped_event.wait()
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
@@ -160,7 +170,9 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def _sink_task_handler(self):
|
||||
# Audio accumlation buffer
|
||||
buffer = bytearray()
|
||||
while True:
|
||||
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
frame = await self._sink_queue.get()
|
||||
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
|
||||
@@ -172,11 +184,16 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._set_camera_images(frame.images)
|
||||
elif isinstance(frame, TransportMessageFrame):
|
||||
await self.send_message(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
await self._internal_push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
await self._internal_push_frame(frame)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._internal_push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
await self._internal_push_frame(frame)
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
if isinstance(frame, EndFrame):
|
||||
await self.stop()
|
||||
running = not isinstance(frame, EndFrame)
|
||||
|
||||
self._sink_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
@@ -200,10 +217,12 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -60,20 +60,20 @@ class BaseTransport(ABC):
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
def decorator(handler):
|
||||
self._add_event_handler(event_name, handler)
|
||||
self.add_event_handler(event_name, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if event_name not in self._event_handlers:
|
||||
raise Exception(f"Event handler {event_name} not registered")
|
||||
self._event_handlers[event_name].append(handler)
|
||||
|
||||
def _register_event_handler(self, event_name: str):
|
||||
if event_name in self._event_handlers:
|
||||
raise Exception(f"Event handler {event_name} already registered")
|
||||
self._event_handlers[event_name] = []
|
||||
|
||||
def _add_event_handler(self, event_name: str, handler):
|
||||
if event_name not in self._event_handlers:
|
||||
raise Exception(f"Event handler {event_name} not registered")
|
||||
self._event_handlers[event_name].append(handler)
|
||||
|
||||
async def _call_event_handler(self, event_name: str, *args, **kwargs):
|
||||
try:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
|
||||
@@ -12,7 +12,7 @@ import wave
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
@@ -57,14 +57,19 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
self._callbacks = callbacks
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
await super().start(frame)
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_messages())
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
if self._websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await self._websocket.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
if self._websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await self._websocket.close()
|
||||
await super().stop()
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._websocket.iter_text():
|
||||
|
||||
@@ -11,7 +11,7 @@ import wave
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.serializers.protobuf import ProtobufFrameSerializer
|
||||
@@ -64,10 +64,15 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
self._server_task = self.get_event_loop().create_task(self._server_task_handler())
|
||||
await super().start(frame)
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._stop_server_event.set()
|
||||
await self._server_task
|
||||
await super().stop()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._server_task.cancel()
|
||||
await self._server_task
|
||||
|
||||
async def _server_task_handler(self):
|
||||
logger.info(f"Starting websocket server on {self._host}:{self._port}")
|
||||
|
||||
@@ -23,6 +23,8 @@ from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
ImageRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -125,8 +127,15 @@ class DailyCallbacks(BaseModel):
|
||||
|
||||
def completion_callback(future):
|
||||
def _callback(*args):
|
||||
if not future.cancelled():
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, *args)
|
||||
def set_result(future, *args):
|
||||
try:
|
||||
if len(args) > 1:
|
||||
future.set_result(args)
|
||||
else:
|
||||
future.set_result(*args)
|
||||
except asyncio.InvalidStateError:
|
||||
pass
|
||||
future.get_loop().call_soon_threadsafe(set_result, future, *args)
|
||||
return _callback
|
||||
|
||||
|
||||
@@ -282,11 +291,12 @@ class DailyTransportClient(EventHandler):
|
||||
await self._callbacks.on_error(error_msg)
|
||||
|
||||
async def _start_transcription(self):
|
||||
future = self._loop.create_future()
|
||||
logger.info(f"Enabling transcription with settings {self._params.transcription_settings}")
|
||||
|
||||
future = self._loop.create_future()
|
||||
self._client.start_transcription(
|
||||
settings=self._params.transcription_settings.model_dump(exclude_none=True),
|
||||
completion=lambda error: future.set_result(error)
|
||||
completion=completion_callback(future)
|
||||
)
|
||||
error = await future
|
||||
if error:
|
||||
@@ -295,14 +305,10 @@ class DailyTransportClient(EventHandler):
|
||||
async def _join(self):
|
||||
future = self._loop.create_future()
|
||||
|
||||
def handle_join_response(data, error):
|
||||
if not future.cancelled():
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, (data, error))
|
||||
|
||||
self._client.join(
|
||||
self._room_url,
|
||||
self._token,
|
||||
completion=handle_join_response,
|
||||
completion=completion_callback(future),
|
||||
client_settings={
|
||||
"inputs": {
|
||||
"camera": {
|
||||
@@ -370,20 +376,14 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
async def _stop_transcription(self):
|
||||
future = self._loop.create_future()
|
||||
self._client.stop_transcription(completion=lambda error: future.set_result(error))
|
||||
self._client.stop_transcription(completion=completion_callback(future))
|
||||
error = await future
|
||||
if error:
|
||||
logger.error(f"Unable to stop transcription: {error}")
|
||||
|
||||
async def _leave(self):
|
||||
future = self._loop.create_future()
|
||||
|
||||
def handle_leave_response(error):
|
||||
if not future.cancelled():
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, error)
|
||||
|
||||
self._client.leave(completion=handle_leave_response)
|
||||
|
||||
self._client.leave(completion=completion_callback(future))
|
||||
return await asyncio.wait_for(future, timeout=10)
|
||||
|
||||
async def cleanup(self):
|
||||
@@ -547,9 +547,19 @@ class DailyInputTransport(BaseInputTransport):
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop()
|
||||
await super().stop(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
# Stop audio thread.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_task.cancel()
|
||||
await self._audio_in_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Parent stop.
|
||||
await super().cancel(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
# Stop audio thread.
|
||||
@@ -664,9 +674,15 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
# Join the room.
|
||||
await self._client.join()
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop()
|
||||
await super().stop(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Parent stop.
|
||||
await super().cancel(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user