# File: AI-VideoAssistant/engine/tests/test_dify_provider.py (167 lines, 5.3 KiB, Python)

from typing import Any, Dict, List, Optional
import pytest
from providers.common.base import LLMMessage
from providers.llm.dify import DifyLLMService
class _FakeStreamResponse:
    """Stand-in for an aiohttp streaming response.

    The queued byte lines are exposed through ``content`` (an async-iterable
    ``_FakeStreamContent``), and ``closed`` records whether ``close()`` ran.
    """

    def __init__(self, lines: List[bytes], status: int = 200):
        # ``closed`` starts False and is flipped only by close().
        self.closed = False
        self.status = status
        self.content = _FakeStreamContent(lines)

    async def json(self) -> Dict[str, Any]:
        """Return an empty JSON body."""
        body: Dict[str, Any] = {}
        return body

    async def text(self) -> str:
        """Return an empty text body."""
        body = ""
        return body

    def close(self) -> None:
        """Mark the response as closed."""
        self.closed = True
class _FakeJSONResponse:
    """Stand-in for an aiohttp JSON response consumed via ``async with``.

    ``json()`` returns a fresh shallow copy of the stored payload, so callers
    can mutate the result without touching ``payload`` itself.
    """

    def __init__(self, payload: Dict[str, Any], status: int = 200):
        self.status = status
        self.payload = payload

    async def __aenter__(self):
        # The response object itself is bound by ``async with ... as``.
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning False lets exceptions raised in the ``async with`` body propagate.
        return False

    async def json(self) -> Dict[str, Any]:
        """Return a shallow copy of the stored payload."""
        snapshot = dict(self.payload)
        return snapshot

    async def text(self) -> str:
        """Return an empty text body."""
        body = ""
        return body
class _FakeStreamContent:
    """Async-iterable over a fixed sequence of byte lines (mimics ``response.content``).

    Each ``async for`` starts a fresh pass over the stored lines, so the
    object can be iterated more than once.
    """

    def __init__(self, lines: List[bytes]):
        # Snapshot the input so later changes to the caller's list are not seen.
        self._lines = [*lines]

    def __aiter__(self):
        # A brand-new async generator per iteration.
        generator = self._iter()
        return generator

    async def _iter(self):
        """Yield each stored line in order."""
        for chunk in self._lines:
            yield chunk
class _FakeClientSession:
    """Replacement for ``aiohttp.ClientSession`` used by the Dify provider tests.

    The provider constructs its own session, so tests queue responses and
    inspect the last request through *class* attributes:

    Queues:
        post_responses: objects returned, in FIFO order, by ``post()``.
        get_payloads: dicts that ``get()`` wraps in ``_FakeJSONResponse``.

    Recorded requests:
        last_post_url / last_post_json: arguments of the most recent ``post()``.
        last_get_url / last_get_params: arguments of the most recent ``get()``.

    Because this state lives on the class, it persists across tests unless a
    test replaces it (e.g. via ``monkeypatch.setattr``) or calls ``reset()``.
    """

    # String forward reference: avoids evaluating the name at class-creation time.
    post_responses: List["_FakeStreamResponse"] = []
    get_payloads: List[Dict[str, Any]] = []
    last_post_url: Optional[str] = None
    last_post_json: Optional[Dict[str, Any]] = None
    last_get_url: Optional[str] = None
    last_get_params: Optional[Dict[str, Any]] = None

    def __init__(self, headers: Optional[Dict[str, str]] = None, **kwargs: Any):
        # Other aiohttp.ClientSession options (timeout, connector, ...) are
        # accepted and ignored so the fake keeps working if the provider
        # starts passing them.
        self.headers = headers or {}
        self.closed = False

    @classmethod
    def reset(cls) -> None:
        """Clear the queued responses and the recorded requests."""
        cls.post_responses = []
        cls.get_payloads = []
        cls.last_post_url = None
        cls.last_post_json = None
        cls.last_get_url = None
        cls.last_get_params = None

    async def close(self) -> None:
        """Mark the session as closed."""
        self.closed = True

    async def post(self, url: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any):
        """Record the request and return the next queued stream response.

        This is a coroutine, so it supports ``await session.post(...)``. It
        does not support ``async with session.post(...)``. Extra request options
        are accepted and ignored.

        Raises:
            AssertionError: if no response has been queued.
        """
        owner = type(self)
        owner.last_post_url = url
        # Copy so later mutation by the caller cannot alter what was recorded.
        owner.last_post_json = dict(json) if json is not None else None
        if not owner.post_responses:
            raise AssertionError("No fake Dify stream response queued")
        return owner.post_responses.pop(0)

    def get(self, url: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any):
        """Record the request and return the next queued payload as a JSON response.

        The returned ``_FakeJSONResponse`` is used with ``async with``. Extra
        request options are accepted and ignored.

        Raises:
            AssertionError: if no payload has been queued.
        """
        owner = type(self)
        owner.last_get_url = url
        owner.last_get_params = dict(params) if params is not None else None
        if not owner.get_payloads:
            raise AssertionError("No fake Dify JSON payload queued")
        return _FakeJSONResponse(owner.get_payloads.pop(0))
@pytest.mark.asyncio
async def test_dify_provider_streams_message_answer_and_tracks_conversation(monkeypatch):
    """A single streamed turn yields two ``text_delta`` events followed by ``done``.

    Both ``message`` and ``agent_message`` events surface as text deltas. The
    conversation id is remembered on the service. The first request goes to
    ``/chat-messages`` without a ``conversation_id``.
    """
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    # The fake keeps its queue and request log on the class. Patch them through
    # monkeypatch so they are restored after this test rather than leaking into
    # (or being polluted by) other tests.
    monkeypatch.setattr(
        _FakeClientSession,
        "post_responses",
        [
            _FakeStreamResponse(
                [
                    b'data: {"event":"message","conversation_id":"conv-1","answer":"Hello "}\n',
                    b"\n",
                    b'data: {"event":"agent_message","conversation_id":"conv-1","answer":"from Dify."}\n',
                    b"\n",
                    b'data: {"event":"message_end","conversation_id":"conv-1"}\n',
                    b"\n",
                ]
            )
        ],
    )
    monkeypatch.setattr(_FakeClientSession, "last_post_url", None)
    monkeypatch.setattr(_FakeClientSession, "last_post_json", None)

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    events = [
        event
        async for event in service.generate_stream([LLMMessage(role="user", content="Hi there")])
    ]

    assert [event.type for event in events] == ["text_delta", "text_delta", "done"]
    assert events[0].text == "Hello "
    assert events[1].text == "from Dify."
    assert service._conversation_id == "conv-1"
    assert _FakeClientSession.last_post_url == "https://dify.example/v1/chat-messages"
    # First turn on a fresh service: no conversation_id is sent yet.
    assert _FakeClientSession.last_post_json == {
        "inputs": {},
        "query": "Hi there",
        "user": service._user_id,
        "response_mode": "streaming",
    }
@pytest.mark.asyncio
async def test_dify_provider_reuses_conversation_id_on_follow_up(monkeypatch):
    """The conversation id returned by the first turn is sent with the second turn."""
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    # Patch class-level fake state via monkeypatch so it is restored after the
    # test and cannot leak between tests.
    monkeypatch.setattr(
        _FakeClientSession,
        "post_responses",
        [
            _FakeStreamResponse(
                [
                    b'data: {"event":"message","conversation_id":"conv-2","answer":"First"}\n',
                    b"\n",
                ]
            ),
            _FakeStreamResponse(
                [
                    b'data: {"event":"message","conversation_id":"conv-2","answer":"Second"}\n',
                    b"\n",
                ]
            ),
        ],
    )
    monkeypatch.setattr(_FakeClientSession, "last_post_url", None)
    monkeypatch.setattr(_FakeClientSession, "last_post_json", None)

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    _ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn one")])]
    # Before any id is known, the request carries no conversation_id.
    assert _FakeClientSession.last_post_json == {
        "inputs": {},
        "query": "Turn one",
        "user": service._user_id,
        "response_mode": "streaming",
    }

    _ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn two")])]
    assert _FakeClientSession.last_post_json == {
        "inputs": {},
        "query": "Turn two",
        "user": service._user_id,
        "response_mode": "streaming",
        "conversation_id": "conv-2",
    }
    # Exactly one POST per turn consumed both queued responses.
    assert _FakeClientSession.post_responses == []
@pytest.mark.asyncio
async def test_dify_provider_loads_initial_greeting_from_parameters(monkeypatch):
    """The initial greeting is the ``opening_statement`` from ``GET /parameters``."""
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    # Patch class-level fake state via monkeypatch so it is restored after the
    # test. The empty POST queue makes any unexpected POST fail loudly instead
    # of consuming a response left over by another test.
    monkeypatch.setattr(
        _FakeClientSession,
        "get_payloads",
        [
            {"opening_statement": "Hello from Dify."},
        ],
    )
    monkeypatch.setattr(_FakeClientSession, "post_responses", [])
    monkeypatch.setattr(_FakeClientSession, "last_get_url", None)
    monkeypatch.setattr(_FakeClientSession, "last_get_params", None)

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    greeting = await service.get_initial_greeting()

    assert greeting == "Hello from Dify."
    assert _FakeClientSession.last_get_url == "https://dify.example/v1/parameters"
    assert _FakeClientSession.last_get_params == {"user": service._user_id}