Compare commits
11 Commits
khk/load-j
...
v0.0.47
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a46eaa838b | ||
|
|
7c432499db | ||
|
|
8d75fcc9f0 | ||
|
|
61d73f81ae | ||
|
|
951255def9 | ||
|
|
e556f34094 | ||
|
|
ccc3691620 | ||
|
|
5321affda7 | ||
|
|
e5ad8dc67b | ||
|
|
46927805bc | ||
|
|
b999b76f70 |
13
CHANGELOG.md
13
CHANGELOG.md
@@ -5,7 +5,7 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.0.47] - 2024-10-22
|
||||
|
||||
### Added
|
||||
|
||||
@@ -15,8 +15,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added a foundational example for Gladia transcription:
|
||||
`13c-gladia-transcription.py`
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `GladiaSTTService` to use the V2 API.
|
||||
|
||||
- Changed `DailyTransport` transcription model to `nova-2-general`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue that would cause an import error when importing
|
||||
`SileroVADAnalyzer` from the old package `pipecat.vad.silero`.
|
||||
|
||||
- Fixed `enable_usage_metrics` to control LLM/TTS usage metrics separately
|
||||
from `enable_metrics`.
|
||||
|
||||
@@ -32,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed `DeepgramSTTService` model to `nova-2-general`.
|
||||
|
||||
- Moved `SileroVAD` audio processor to `processors.audio.vad`.
|
||||
|
||||
- Module `utils.audio` is now `audio.utils`. A new `resample_audio` function has
|
||||
|
||||
@@ -1,12 +1,43 @@
|
||||
# Simple Chatbot
|
||||
# Chatbot with canonical-metrics
|
||||
|
||||
<img src="image.png" width="420px">
|
||||
This project implements a chatbot using a pipeline architecture that integrates audio processing, transcription, and a language model for conversational interactions. The chatbot operates within a daily communication environment, utilizing various services for text-to-speech and language model responses.
|
||||
|
||||
This app connects you to a chatbot powered by GPT-4, complete with animations generated by Stable Video Diffusion.
|
||||
## Features
|
||||
|
||||
See a video of it in action: https://x.com/kwindla/status/1778628911817183509
|
||||
- **Audio Input and Output**: Captures microphone input and plays back audio responses.
|
||||
- **Voice Activity Detection**: Utilizes Silero VAD to manage audio input intelligently.
|
||||
- **Text-to-Speech**: Integrates ElevenLabs TTS service to convert text responses into audio.
|
||||
- **Language Model Interaction**: Uses OpenAI's GPT-4 model to generate responses based on user input.
|
||||
- **Transcription Services**: Captures and transcribes participant speech for analytics.
|
||||
- **Metrics Collection**: Sends audio data for analysis via Canonical Metrics Service.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.7+
|
||||
- `aiohttp`
|
||||
- `loguru`
|
||||
- `python-dotenv`
|
||||
- Additional libraries from the `pipecat` package.
|
||||
|
||||
## Setup
|
||||
|
||||
1. Clone the repository.
|
||||
2. Install the required packages.
|
||||
3. Set up environment variables for API keys:
|
||||
- `OPENAI_API_KEY`
|
||||
- `ELEVENLABS_API_KEY`
|
||||
- `CANONICAL_API_KEY`
|
||||
- `CANONICAL_API_URL`
|
||||
4. Run the script.
|
||||
|
||||
## Usage
|
||||
|
||||
The chatbot introduces itself and engages in conversations, providing brief and creative responses. Designed for flexibility, it can support multiple languages with appropriate configuration.
|
||||
|
||||
## Events
|
||||
|
||||
- Participants joining or leaving the call are handled dynamically, adjusting the chatbot's behavior accordingly.
|
||||
|
||||
And a quick video walkthrough of the code: https://www.loom.com/share/13df1967161f4d24ade054e7f8753416
|
||||
|
||||
ℹ️ The first time, things might take extra time to get started since VAD (Voice Activity Detection) model needs to be downloaded.
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.frames.frames import EndFrame, LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -20,12 +24,6 @@ from pipecat.services.gladia import GladiaSTTService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -90,6 +88,11 @@ async def main():
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
# Register an event handler to exit the application when the user leaves.
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
@@ -21,9 +21,9 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"aiohttp~=3.10.3",
|
||||
"loguru~=0.7.2",
|
||||
"Markdown~=3.7",
|
||||
"numpy~=1.26.4",
|
||||
"loguru~=0.7.2",
|
||||
"Pillow~=10.4.0",
|
||||
"protobuf~=4.25.4",
|
||||
"pydantic~=2.8.2",
|
||||
|
||||
@@ -8,6 +8,7 @@ import base64
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
@@ -23,7 +24,6 @@ from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
# See .env.example for Gladia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
@@ -38,15 +38,16 @@ class GladiaSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
sample_rate: Optional[int] = 16000
|
||||
language: Optional[Language] = Language.EN
|
||||
transcription_hint: Optional[str] = None
|
||||
endpointing: Optional[int] = 200
|
||||
prosody: Optional[bool] = None
|
||||
endpointing: Optional[float] = 0.2
|
||||
maximum_duration_without_endpointing: Optional[int] = 10
|
||||
audio_enhancer: Optional[bool] = None
|
||||
words_accurate_timestamps: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
url: str = "wss://api.gladia.io/audio/text/audio-transcription",
|
||||
url: str = "https://api.gladia.io/v2/live",
|
||||
confidence: float = 0.5,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -56,101 +57,82 @@ class GladiaSTTService(STTService):
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._settings = {
|
||||
"encoding": "wav/pcm",
|
||||
"bit_depth": 16,
|
||||
"sample_rate": params.sample_rate,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"transcription_hint": params.transcription_hint,
|
||||
"channels": 1,
|
||||
"language_config": {
|
||||
"languages": [self.language_to_service_language(params.language)]
|
||||
if params.language
|
||||
else [],
|
||||
"code_switching": False,
|
||||
},
|
||||
"endpointing": params.endpointing,
|
||||
"prosody": params.prosody,
|
||||
"maximum_duration_without_endpointing": params.maximum_duration_without_endpointing,
|
||||
"pre_processing": {
|
||||
"audio_enhancer": params.audio_enhancer,
|
||||
},
|
||||
"realtime_processing": {
|
||||
"words_accurate_timestamps": params.words_accurate_timestamps,
|
||||
},
|
||||
}
|
||||
self._confidence = confidence
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bulgarian"
|
||||
case Language.CA:
|
||||
return "catalan"
|
||||
case Language.ZH:
|
||||
return "chinese"
|
||||
case Language.CS:
|
||||
return "czech"
|
||||
case Language.DA:
|
||||
return "danish"
|
||||
case Language.NL:
|
||||
return "dutch"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "english"
|
||||
case Language.ET:
|
||||
return "estonian"
|
||||
case Language.FI:
|
||||
return "finnish"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "french"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "german"
|
||||
case Language.EL:
|
||||
return "greek"
|
||||
case Language.HI:
|
||||
return "hindi"
|
||||
case Language.HU:
|
||||
return "hungarian"
|
||||
case Language.ID:
|
||||
return "indonesian"
|
||||
case Language.IT:
|
||||
return "italian"
|
||||
case Language.JA:
|
||||
return "japanese"
|
||||
case Language.KO:
|
||||
return "korean"
|
||||
case Language.LV:
|
||||
return "latvian"
|
||||
case Language.LT:
|
||||
return "lithuanian"
|
||||
case Language.MS:
|
||||
return "malay"
|
||||
case Language.NO:
|
||||
return "norwegian"
|
||||
case Language.PL:
|
||||
return "polish"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "portuguese"
|
||||
case Language.RO:
|
||||
return "romanian"
|
||||
case Language.RU:
|
||||
return "russian"
|
||||
case Language.SK:
|
||||
return "slovak"
|
||||
case Language.ES:
|
||||
return "spanish"
|
||||
case Language.SV:
|
||||
return "slovenian"
|
||||
case Language.TH:
|
||||
return "thai"
|
||||
case Language.TR:
|
||||
return "turkish"
|
||||
case Language.UK:
|
||||
return "ukrainian"
|
||||
case Language.VI:
|
||||
return "vietnamese"
|
||||
return None
|
||||
language_map = {
|
||||
Language.BG: "bg",
|
||||
Language.CA: "ca",
|
||||
Language.ZH: "zh",
|
||||
Language.CS: "cs",
|
||||
Language.DA: "da",
|
||||
Language.NL: "nl",
|
||||
Language.EN: "en",
|
||||
Language.EN_US: "en",
|
||||
Language.EN_AU: "en",
|
||||
Language.EN_GB: "en",
|
||||
Language.EN_NZ: "en",
|
||||
Language.EN_IN: "en",
|
||||
Language.ET: "et",
|
||||
Language.FI: "fi",
|
||||
Language.FR: "fr",
|
||||
Language.FR_CA: "fr",
|
||||
Language.DE: "de",
|
||||
Language.DE_CH: "de",
|
||||
Language.EL: "el",
|
||||
Language.HI: "hi",
|
||||
Language.HU: "hu",
|
||||
Language.ID: "id",
|
||||
Language.IT: "it",
|
||||
Language.JA: "ja",
|
||||
Language.KO: "ko",
|
||||
Language.LV: "lv",
|
||||
Language.LT: "lt",
|
||||
Language.MS: "ms",
|
||||
Language.NO: "no",
|
||||
Language.PL: "pl",
|
||||
Language.PT: "pt",
|
||||
Language.PT_BR: "pt",
|
||||
Language.RO: "ro",
|
||||
Language.RU: "ru",
|
||||
Language.SK: "sk",
|
||||
Language.ES: "es",
|
||||
Language.SV: "sv",
|
||||
Language.TH: "th",
|
||||
Language.TR: "tr",
|
||||
Language.UK: "uk",
|
||||
Language.VI: "vi",
|
||||
}
|
||||
return language_map.get(language)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._websocket = await websockets.connect(self._url)
|
||||
response = await self._setup_gladia()
|
||||
self._websocket = await websockets.connect(response["url"])
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
await self._setup_gladia()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._send_stop_recording()
|
||||
await self._websocket.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
@@ -164,39 +146,37 @@ class GladiaSTTService(STTService):
|
||||
yield None
|
||||
|
||||
async def _setup_gladia(self):
|
||||
configuration = {
|
||||
"x_gladia_key": self._api_key,
|
||||
"encoding": "WAV/PCM",
|
||||
"model_type": "fast",
|
||||
"language_behaviour": "manual",
|
||||
"sample_rate": self._settings["sample_rate"],
|
||||
"language": self._settings["language"],
|
||||
"transcription_hint": self._settings["transcription_hint"],
|
||||
"endpointing": self._settings["endpointing"],
|
||||
"prosody": self._settings["prosody"],
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(configuration))
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self._url,
|
||||
headers={"X-Gladia-Key": self._api_key, "Content-Type": "application/json"},
|
||||
json=self._settings,
|
||||
) as response:
|
||||
if response.ok:
|
||||
return await response.json()
|
||||
else:
|
||||
logger.error(
|
||||
f"Gladia error: {response.status}: {response.text or response.reason}"
|
||||
)
|
||||
raise Exception(f"Failed to initialize Gladia session: {response.status}")
|
||||
|
||||
async def _send_audio(self, audio: bytes):
|
||||
message = {"frames": base64.b64encode(audio).decode("utf-8")}
|
||||
data = base64.b64encode(audio).decode("utf-8")
|
||||
message = {"type": "audio_chunk", "data": {"chunk": data}}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
|
||||
async def _send_stop_recording(self):
|
||||
await self._websocket.send(json.dumps({"type": "stop_recording"}))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
async for message in self._websocket:
|
||||
utterance = json.loads(message)
|
||||
if not utterance:
|
||||
continue
|
||||
|
||||
if "error" in utterance:
|
||||
message = utterance["message"]
|
||||
logger.error(f"Gladia error: {message}")
|
||||
elif "confidence" in utterance:
|
||||
type = utterance["type"]
|
||||
confidence = utterance["confidence"]
|
||||
transcript = utterance["transcription"]
|
||||
content = json.loads(message)
|
||||
if content["type"] == "transcript":
|
||||
utterance = content["data"]["utterance"]
|
||||
confidence = utterance.get("confidence", 0)
|
||||
transcript = utterance["text"]
|
||||
if confidence >= self._confidence:
|
||||
if type == "final":
|
||||
if content["data"]["is_final"]:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601())
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Mapping, Optional
|
||||
@@ -20,7 +21,7 @@ from daily import (
|
||||
VirtualSpeakerDevice,
|
||||
)
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
from pipecat.frames.frames import (
|
||||
@@ -93,8 +94,8 @@ class DailyDialinSettings(BaseModel):
|
||||
|
||||
class DailyTranscriptionSettings(BaseModel):
|
||||
language: str = "en"
|
||||
tier: str = "nova"
|
||||
model: str = "2-conversationalai"
|
||||
tier: Optional[str] = None
|
||||
model: str = "nova-2-general"
|
||||
profanity_filter: bool = True
|
||||
redact: bool = False
|
||||
endpointing: bool = True
|
||||
@@ -102,6 +103,16 @@ class DailyTranscriptionSettings(BaseModel):
|
||||
includeRawResponse: bool = True
|
||||
extra: Mapping[str, Any] = {"interim_results": True}
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_deprecated_fields(cls, values):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
if "tier" in values:
|
||||
warnings.warn(
|
||||
"Field 'tier' is deprecated, use 'model' instead.", DeprecationWarning
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class DailyParams(TransportParams):
|
||||
api_url: str = "https://api.daily.co/v1"
|
||||
|
||||
@@ -4,8 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
import warnings
|
||||
|
||||
logger.warning("DEPRECATED: Package `pipecat.vad` is deprecated, use `pipecat.audio.vad` instead.")
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package `pipecat.vad` is deprecated, use `pipecat.audio.vad` instead", DeprecationWarning
|
||||
)
|
||||
|
||||
from ..audio.vad.silero import SileroVAD, SileroVADAnalyzer
|
||||
from ..audio.vad.silero import SileroVADAnalyzer
|
||||
from ..processors.audio.vad.silero import SileroVAD
|
||||
|
||||
@@ -4,8 +4,12 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
import warnings
|
||||
|
||||
logger.warning("DEPRECATED: Package `pipecat.vad` is deprecated, use `pipecat.audio.vad` instead.")
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package `pipecat.vad` is deprecated, use `pipecat.audio.vad` instead", DeprecationWarning
|
||||
)
|
||||
|
||||
from ..audio.vad.vad_analyzer import VADAnalyzer, VADParams, VADState
|
||||
|
||||
1
src/pipecat/workflow/.gitignore
vendored
1
src/pipecat/workflow/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
*.json
|
||||
@@ -1 +0,0 @@
|
||||
python -m pipecat.workflow.workflow_test to run
|
||||
@@ -1,18 +0,0 @@
|
||||
from ..services.cartesia import CartesiaTTSService
|
||||
from ..services.openai import OpenAILLMService
|
||||
from ..services.deepgram import DeepgramSTTService
|
||||
from ..transports.services.daily import DailyTransport
|
||||
from ..processors.frame_processor import FrameProcessor
|
||||
|
||||
# Map workflow types to their corresponding Python classes
|
||||
WORKFLOW_MAPPING = {
|
||||
"inputs/audio_input": DailyTransport,
|
||||
"processors/speech_to_text": DeepgramSTTService,
|
||||
"processors/llm": OpenAILLMService,
|
||||
"processors/text_to_speech": CartesiaTTSService,
|
||||
"outputs/audio_output": DailyTransport,
|
||||
}
|
||||
|
||||
|
||||
def get_processor_class(node_type: str) -> type[FrameProcessor]:
|
||||
return WORKFLOW_MAPPING.get(node_type, FrameProcessor)
|
||||
@@ -1,65 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from ..pipeline.pipeline import Pipeline
|
||||
from ..pipeline.runner import PipelineRunner
|
||||
from ..pipeline.task import PipelineTask, PipelineParams
|
||||
from .workflow_translator import translate_workflow
|
||||
from ..services.openai import OpenAIUserContextAggregator
|
||||
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def main():
|
||||
print("Starting workflow test")
|
||||
|
||||
# Update the path to the workflow.json file
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
workflow_path = os.path.join(script_dir, "workflow.json")
|
||||
print(f"Workflow path: {workflow_path}")
|
||||
|
||||
# Translate the workflow to a list of processors
|
||||
print("Translating workflow to processors")
|
||||
processors, daily_transport = translate_workflow(workflow_path)
|
||||
print(f"Processors created: {processors}")
|
||||
|
||||
# Create a pipeline from the processors
|
||||
print("Creating pipeline")
|
||||
pipeline = Pipeline(processors)
|
||||
print(f"Pipeline created: {pipeline}")
|
||||
|
||||
# Create a pipeline task
|
||||
print("Creating pipeline task")
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
print(f"Pipeline task created: {task}")
|
||||
|
||||
# Create a pipeline runner
|
||||
print("Creating pipeline runner")
|
||||
runner = PipelineRunner()
|
||||
print(f"Pipeline runner created: {runner}")
|
||||
|
||||
user_context_aggregator = next(
|
||||
p for p in processors if isinstance(p, OpenAIUserContextAggregator)
|
||||
)
|
||||
|
||||
@daily_transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
await task.queue_frames([user_context_aggregator.get_context_frame()])
|
||||
|
||||
# Run the pipeline
|
||||
print("Running the pipeline")
|
||||
try:
|
||||
await runner.run(task)
|
||||
print("Pipeline execution completed successfully")
|
||||
except Exception as e:
|
||||
print(f"Error during pipeline execution: {e}")
|
||||
|
||||
print("Workflow test completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting main execution")
|
||||
asyncio.run(main())
|
||||
print("Main execution completed")
|
||||
@@ -1,140 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from .workflow_mapping import get_processor_class
|
||||
from ..processors.frame_processor import FrameProcessor
|
||||
from ..transports.services.daily import DailyParams
|
||||
from ..processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from ..audio.vad.silero import SileroVADAnalyzer
|
||||
from ..transports.base_transport import BaseTransport
|
||||
|
||||
|
||||
def load_workflow(file_path: str) -> Dict[str, Any]:
|
||||
print(f"Loading workflow from file: {file_path}")
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
workflow = json.load(f)
|
||||
print(f"Workflow loaded successfully: {workflow}")
|
||||
return workflow
|
||||
except Exception as e:
|
||||
print(f"Error loading workflow: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_processor(node: Dict[str, Any], next_node: Dict[str, Any] = None) -> FrameProcessor:
|
||||
print(f"Creating processor for node: {node['id']} of type: {node['type']}")
|
||||
processor_class = get_processor_class(node["type"])
|
||||
print(f"Processor class: {processor_class}")
|
||||
|
||||
# Extract relevant properties for initialization
|
||||
init_params = {}
|
||||
if node["type"] == "inputs/audio_input":
|
||||
init_params = {
|
||||
"room_url": os.getenv("DAILY_SAMPLE_ROOM_URL"),
|
||||
"token": "",
|
||||
"bot_name": "PipecatBot",
|
||||
"params": DailyParams(
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_audio_passthrough=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
elif node["type"] == "processors/speech_to_text":
|
||||
init_params = {
|
||||
"api_key": os.getenv("DEEPGRAM_API_KEY"),
|
||||
}
|
||||
elif node["type"] == "processors/text_to_speech":
|
||||
init_params = {
|
||||
"api_key": os.getenv("CARTESIA_API_KEY"),
|
||||
"voice_id": "79a125e8-cd45-4c13-8a67-188112f4dd22",
|
||||
}
|
||||
|
||||
print(f"Initialization parameters: {init_params}")
|
||||
processor = processor_class(**init_params)
|
||||
print(f"Processor created: {processor}")
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def create_pipeline(workflow: Dict[str, Any]) -> Tuple[List[FrameProcessor], BaseTransport]:
|
||||
print("Creating pipeline from workflow")
|
||||
nodes = {node["id"]: node for node in workflow["nodes"]}
|
||||
links = workflow["links"]
|
||||
|
||||
print(f"Nodes: {nodes}")
|
||||
print(f"Links: {links}")
|
||||
|
||||
# Create a dictionary to store processors
|
||||
processors = {}
|
||||
daily_transport = None
|
||||
llm_service = None
|
||||
context_aggregator = None
|
||||
|
||||
# Create processors for each node
|
||||
for node_id, node in nodes.items():
|
||||
print(f"Creating processor for node: {node_id}")
|
||||
|
||||
if node["type"] == "inputs/audio_input":
|
||||
daily_transport = create_processor(node)
|
||||
processors[node_id] = {"processor": daily_transport, "type": node["type"]}
|
||||
elif node["type"] == "outputs/audio_output":
|
||||
if daily_transport is None:
|
||||
raise ValueError("Audio output transport node found before audio input node")
|
||||
processors[node_id] = {"processor": daily_transport, "type": node["type"]}
|
||||
elif node["type"] == "processors/llm":
|
||||
llm_service = create_processor(node)
|
||||
processors[node_id] = {"processor": llm_service, "type": node["type"]}
|
||||
context = OpenAILLMContext(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant. Your name is Housecat. You are participating in a voice conversation. Keep your answers brief. For punctuation use only period, comma, and question mark.",
|
||||
},
|
||||
{"role": "user", "content": "Introduce yourself."},
|
||||
]
|
||||
)
|
||||
context_aggregator = llm_service.create_context_aggregator(context)
|
||||
print(f"Context aggregator created: {context_aggregator}")
|
||||
else:
|
||||
processors[node_id] = {"processor": create_processor(node), "type": node["type"]}
|
||||
|
||||
# Create the pipeline based on the links
|
||||
pipeline = []
|
||||
for link in links:
|
||||
source_id, _, _, target_id, _, _ = link
|
||||
print(f"Processing link: {source_id} -> {target_id}")
|
||||
|
||||
if processors[source_id]["processor"] not in pipeline:
|
||||
print(f"Adding source processor: {source_id}, {processors[source_id]['processor']}")
|
||||
if processors[source_id]["type"] == "inputs/audio_input":
|
||||
pipeline.append(processors[source_id]["processor"].input())
|
||||
else:
|
||||
pipeline.append(processors[source_id]["processor"])
|
||||
|
||||
if processors[target_id]["processor"] not in pipeline and target_id in processors:
|
||||
print(f"Adding target processor: {target_id} {processors[target_id]['processor']}")
|
||||
if processors[target_id]["type"] == "outputs/audio_output":
|
||||
pipeline.append(processors[target_id]["processor"].output())
|
||||
elif processors[target_id]["type"] == "processors/llm":
|
||||
print("TRYING TO LINK AGGREGATOR")
|
||||
if context_aggregator:
|
||||
print("AGGREGATOR FOUND")
|
||||
pipeline.append(context_aggregator.user())
|
||||
pipeline.append(processors[target_id]["processor"])
|
||||
else:
|
||||
pipeline.append(processors[target_id]["processor"])
|
||||
|
||||
print(f"Pipeline created with {len(pipeline)} processors")
|
||||
print(f"Pipeline: {pipeline}")
|
||||
|
||||
return pipeline, daily_transport
|
||||
|
||||
|
||||
def translate_workflow(file_path: str) -> Tuple[List[FrameProcessor], BaseTransport]:
|
||||
print(f"Translating workflow from file: {file_path}")
|
||||
workflow = load_workflow(file_path)
|
||||
pipeline, transport = create_pipeline(workflow)
|
||||
print("Workflow translation completed")
|
||||
return pipeline, transport
|
||||
Reference in New Issue
Block a user