# Tests for the Dify LLM provider, using fake aiohttp session/response doubles.
from typing import Any, Dict, List, Optional

import pytest

from providers.common.base import LLMMessage
from providers.llm.dify import DifyLLMService
class _FakeStreamResponse:
    """Stand-in for an aiohttp streaming response backed by canned SSE byte lines."""

    def __init__(self, lines: List[bytes], status: int = 200):
        self.status = status
        self.closed = False
        # The service under test iterates `response.content` line by line.
        self.content = _FakeStreamContent(lines)

    async def json(self) -> Dict[str, Any]:
        """Streaming responses carry no JSON body; return an empty dict."""
        return {}

    async def text(self) -> str:
        """Streaming responses carry no text body; return an empty string."""
        return ""

    def close(self) -> None:
        """Record that the caller released the response."""
        self.closed = True


class _FakeJSONResponse:
    """Async-context-manager fake that hands back a fixed JSON payload."""

    def __init__(self, payload: Dict[str, Any], status: int = 200):
        self.status = status
        self.payload = payload

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Never swallow exceptions raised inside the context.
        return False

    async def json(self) -> Dict[str, Any]:
        # Hand out a shallow copy so callers cannot mutate the queued payload.
        return dict(self.payload)

    async def text(self) -> str:
        return ""


class _FakeStreamContent:
    """Mimics aiohttp's `response.content`: async-iterates over canned byte lines."""

    def __init__(self, lines: List[bytes]):
        # Copy defensively so later mutation of the fixture list is invisible here.
        self._lines = list(lines)

    def __aiter__(self):
        return self._iter()

    async def _iter(self):
        for chunk in self._lines:
            yield chunk


class _FakeClientSession:
    """Class-level-queue fake for `aiohttp.ClientSession`.

    Responses are queued on the class (``post_responses`` / ``get_payloads``)
    before a test runs, and the most recent request arguments are recorded on
    the class so the test can assert against them afterwards.
    """

    post_responses: List[_FakeStreamResponse] = []
    get_payloads: List[Dict[str, Any]] = []
    last_post_url: Optional[str] = None
    last_post_json: Optional[Dict[str, Any]] = None
    last_get_url: Optional[str] = None
    last_get_params: Optional[Dict[str, Any]] = None

    def __init__(self, headers: Optional[Dict[str, str]] = None):
        self.closed = False
        self.headers = headers or {}

    async def close(self) -> None:
        self.closed = True

    async def post(self, url: str, json: Dict[str, Any]):
        """Record the request and pop the next queued streaming response."""
        cls = type(self)
        cls.last_post_url = url
        cls.last_post_json = dict(json)
        if not cls.post_responses:
            raise AssertionError("No fake Dify stream response queued")
        return cls.post_responses.pop(0)

    def get(self, url: str, params: Dict[str, Any]):
        """Record the request and wrap the next queued payload as a JSON response."""
        cls = type(self)
        cls.last_get_url = url
        cls.last_get_params = dict(params)
        if not cls.get_payloads:
            raise AssertionError("No fake Dify JSON payload queued")
        return _FakeJSONResponse(cls.get_payloads.pop(0))


@pytest.mark.asyncio
async def test_dify_provider_streams_message_answer_and_tracks_conversation(monkeypatch):
    """A streamed chat turn yields text deltas plus a done event and records the conversation id."""
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    sse_lines = [
        b'data: {"event":"message","conversation_id":"conv-1","answer":"Hello "}\n',
        b"\n",
        b'data: {"event":"agent_message","conversation_id":"conv-1","answer":"from Dify."}\n',
        b"\n",
        b'data: {"event":"message_end","conversation_id":"conv-1"}\n',
        b"\n",
    ]
    _FakeClientSession.post_responses = [_FakeStreamResponse(sse_lines)]

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    prompt = [LLMMessage(role="user", content="Hi there")]
    events = [event async for event in service.generate_stream(prompt)]

    # Both "message" and "agent_message" events surface as text deltas.
    assert [event.type for event in events] == ["text_delta", "text_delta", "done"]
    assert events[0].text == "Hello "
    assert events[1].text == "from Dify."
    # The provider must remember the conversation id for follow-up turns.
    assert service._conversation_id == "conv-1"
    assert _FakeClientSession.last_post_url == "https://dify.example/v1/chat-messages"
    expected_body = {
        "inputs": {},
        "query": "Hi there",
        "user": service._user_id,
        "response_mode": "streaming",
    }
    assert _FakeClientSession.last_post_json == expected_body


@pytest.mark.asyncio
async def test_dify_provider_reuses_conversation_id_on_follow_up(monkeypatch):
    """The second turn must carry the conversation_id captured during the first turn."""
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    _FakeClientSession.post_responses = [
        _FakeStreamResponse(
            [
                b'data: {"event":"message","conversation_id":"conv-2","answer":"First"}\n',
                b"\n",
            ]
        ),
        _FakeStreamResponse(
            [
                b'data: {"event":"message","conversation_id":"conv-2","answer":"Second"}\n',
                b"\n",
            ]
        ),
    ]

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    # Drain two consecutive turns; only the second request body is inspected.
    for query in ("Turn one", "Turn two"):
        async for _ in service.generate_stream([LLMMessage(role="user", content=query)]):
            pass

    assert _FakeClientSession.last_post_json == {
        "inputs": {},
        "query": "Turn two",
        "user": service._user_id,
        "response_mode": "streaming",
        "conversation_id": "conv-2",
    }


@pytest.mark.asyncio
async def test_dify_provider_loads_initial_greeting_from_parameters(monkeypatch):
    """The initial greeting comes from the app's /parameters opening_statement field."""
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    _FakeClientSession.get_payloads = [{"opening_statement": "Hello from Dify."}]

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    greeting = await service.get_initial_greeting()

    assert greeting == "Hello from Dify."
    # The lookup should hit /parameters scoped to the service's synthetic user id.
    assert _FakeClientSession.last_get_url == "https://dify.example/v1/parameters"
    assert _FakeClientSession.last_get_params == {"user": service._user_id}