feat: Implement Dify LLM provider and update related configurations and tests
This commit is contained in:
166
engine/tests/test_dify_provider.py
Normal file
166
engine/tests/test_dify_provider.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.common.base import LLMMessage
|
||||
from providers.llm.dify import DifyLLMService
|
||||
|
||||
|
||||
class _FakeStreamResponse:
|
||||
def __init__(self, lines: List[bytes], status: int = 200):
|
||||
self.content = _FakeStreamContent(lines)
|
||||
self.status = status
|
||||
self.closed = False
|
||||
|
||||
async def json(self) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
return ""
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
class _FakeJSONResponse:
|
||||
def __init__(self, payload: Dict[str, Any], status: int = 200):
|
||||
self.payload = payload
|
||||
self.status = status
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def json(self) -> Dict[str, Any]:
|
||||
return dict(self.payload)
|
||||
|
||||
async def text(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _FakeStreamContent:
|
||||
def __init__(self, lines: List[bytes]):
|
||||
self._lines = list(lines)
|
||||
|
||||
def __aiter__(self):
|
||||
return self._iter()
|
||||
|
||||
async def _iter(self):
|
||||
for line in self._lines:
|
||||
yield line
|
||||
|
||||
|
||||
class _FakeClientSession:
|
||||
post_responses: List[_FakeStreamResponse] = []
|
||||
get_payloads: List[Dict[str, Any]] = []
|
||||
last_post_url: Optional[str] = None
|
||||
last_post_json: Optional[Dict[str, Any]] = None
|
||||
last_get_url: Optional[str] = None
|
||||
last_get_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __init__(self, headers: Optional[Dict[str, str]] = None):
|
||||
self.headers = headers or {}
|
||||
self.closed = False
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
async def post(self, url: str, json: Dict[str, Any]):
|
||||
type(self).last_post_url = url
|
||||
type(self).last_post_json = dict(json)
|
||||
if not type(self).post_responses:
|
||||
raise AssertionError("No fake Dify stream response queued")
|
||||
return type(self).post_responses.pop(0)
|
||||
|
||||
def get(self, url: str, params: Dict[str, Any]):
|
||||
type(self).last_get_url = url
|
||||
type(self).last_get_params = dict(params)
|
||||
if not type(self).get_payloads:
|
||||
raise AssertionError("No fake Dify JSON payload queued")
|
||||
return _FakeJSONResponse(type(self).get_payloads.pop(0))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dify_provider_streams_message_answer_and_tracks_conversation(monkeypatch):
|
||||
monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
|
||||
_FakeClientSession.post_responses = [
|
||||
_FakeStreamResponse(
|
||||
[
|
||||
b'data: {"event":"message","conversation_id":"conv-1","answer":"Hello "}\n',
|
||||
b"\n",
|
||||
b'data: {"event":"agent_message","conversation_id":"conv-1","answer":"from Dify."}\n',
|
||||
b"\n",
|
||||
b'data: {"event":"message_end","conversation_id":"conv-1"}\n',
|
||||
b"\n",
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
|
||||
await service.connect()
|
||||
|
||||
events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Hi there")])]
|
||||
|
||||
assert [event.type for event in events] == ["text_delta", "text_delta", "done"]
|
||||
assert events[0].text == "Hello "
|
||||
assert events[1].text == "from Dify."
|
||||
assert service._conversation_id == "conv-1"
|
||||
assert _FakeClientSession.last_post_url == "https://dify.example/v1/chat-messages"
|
||||
assert _FakeClientSession.last_post_json == {
|
||||
"inputs": {},
|
||||
"query": "Hi there",
|
||||
"user": service._user_id,
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dify_provider_reuses_conversation_id_on_follow_up(monkeypatch):
|
||||
monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
|
||||
_FakeClientSession.post_responses = [
|
||||
_FakeStreamResponse(
|
||||
[
|
||||
b'data: {"event":"message","conversation_id":"conv-2","answer":"First"}\n',
|
||||
b"\n",
|
||||
]
|
||||
),
|
||||
_FakeStreamResponse(
|
||||
[
|
||||
b'data: {"event":"message","conversation_id":"conv-2","answer":"Second"}\n',
|
||||
b"\n",
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
|
||||
await service.connect()
|
||||
|
||||
_ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn one")])]
|
||||
_ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn two")])]
|
||||
|
||||
assert _FakeClientSession.last_post_json == {
|
||||
"inputs": {},
|
||||
"query": "Turn two",
|
||||
"user": service._user_id,
|
||||
"response_mode": "streaming",
|
||||
"conversation_id": "conv-2",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dify_provider_loads_initial_greeting_from_parameters(monkeypatch):
|
||||
monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
|
||||
_FakeClientSession.get_payloads = [
|
||||
{"opening_statement": "Hello from Dify."},
|
||||
]
|
||||
|
||||
service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
|
||||
await service.connect()
|
||||
|
||||
greeting = await service.get_initial_greeting()
|
||||
|
||||
assert greeting == "Hello from Dify."
|
||||
assert _FakeClientSession.last_get_url == "https://dify.example/v1/parameters"
|
||||
assert _FakeClientSession.last_get_params == {"user": service._user_id}
|
||||
Reference in New Issue
Block a user