Compare commits

...

1 Commits

Author SHA1 Message Date
Moishe Lettvin
b168c53e44 Adding some more doscstrings, cleanup 2024-03-08 09:39:51 -05:00
2 changed files with 33 additions and 9 deletions

View File

@@ -16,13 +16,20 @@ from dailyai.pipeline.frames import (
Frame,
TextFrame,
TranscriptionQueueFrame,
UserStoppedSpeakingFrame
)
from abc import abstractmethod
from typing import AsyncGenerator, AsyncIterable, BinaryIO, Iterable, List
class AIService(FrameProcessor):
""" This is the base class for various AI services (LLM, TTS and Image)
This class adds some convenienence functions to run, effectively, a one-stage
pipeline where the incoming frames can come from an iterable or queue
and the processed frames go to a queue. Child classes extend those convenience
functions, eg. TTS's `say` method runs the TTS and emits the AudioFrames to a
queue.
"""
def __init__(self):
self.logger = logging.getLogger("dailyai")
@@ -30,12 +37,17 @@ class AIService(FrameProcessor):
def stop(self):
pass
async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None:
async def run_to_queue(
self,
queue: asyncio.Queue,
frames: Iterable[Frame] | AsyncIterable[Frame] | asyncio.Queue[Frame]
) -> None:
""" Process the given frames (from an iterable or queue) and send them to
the given queue.
"""
async for frame in self.run(frames):
await queue.put(frame)
if add_end_of_stream:
await queue.put(EndFrame())
async def run(
self,
@@ -43,6 +55,16 @@ class AIService(FrameProcessor):
| AsyncIterable[Frame]
| asyncio.Queue[Frame],
) -> AsyncGenerator[Frame, None]:
""" Generates 0 or more frames from the given iterable or queue.
This is a convenience function to take a collection of frames, process
them, and yield processed frames.
The preferred way to use FrameProcessors is with a pipeline, but if you
have a very simple case (eg. a list of static text blocks you want to speak,
or a list of static image description you want to render) this function
will be helpful.
"""
try:
if isinstance(frames, AsyncIterable):
async for frame in frames:
@@ -73,14 +95,14 @@ class LLMService(AIService):
self._messages = messages
@abstractmethod
async def run_llm_async(self, messages) -> AsyncGenerator[str, None]:
async def run_llm_async(self, messages, tool_choice=None) -> AsyncGenerator[str, None]:
yield ""
@abstractmethod
async def run_llm(self, messages) -> str:
pass
async def process_frame(self, frame: Frame, tool_choice: str = None) -> AsyncGenerator[Frame, None]:
async def process_frame(self, frame: Frame, tool_choice: str | None = None) -> AsyncGenerator[Frame, None]:
if isinstance(frame, LLMMessagesQueueFrame):
function_name = ""
arguments = ""

View File

@@ -19,11 +19,13 @@ class OLLamaLLMService(LLMService):
model=self._model
)
async def run_llm_async(self, messages) -> AsyncGenerator[str, None]:
async def run_llm_async(self, messages, tool_choice=None) -> AsyncGenerator[str, None]:
messages_for_log = json.dumps(messages)
self.logger.debug(f"Generating chat via openai: {messages_for_log}")
chunks = await self._client.chat.completions.create(model=self._model, stream=True, messages=messages)
chunks = await self._client.chat.completions.create(
model=self._model, stream=True, messages=messages
)
async for chunk in chunks:
if len(chunk.choices) == 0:
continue
@@ -33,7 +35,7 @@ class OLLamaLLMService(LLMService):
async def run_llm(self, messages) -> str | None:
messages_for_log = json.dumps(messages)
self.logger.debug(f"Generating chat via openai: {messages_for_log}")
self.logger.debug(f"Generating chat via ollama: {messages_for_log}")
response = await self._client.chat.completions.create(model=self._model, stream=False, messages=messages)
if response and len(response.choices) > 0: