feat: Implement Dify LLM provider and update related configurations and tests
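
Dify joins the supported LLM providers. The new DifyLLMService posts user
queries to the Dify Service API chat-messages endpoint in streaming mode,
parses the SSE events into text deltas, reuses the server-assigned
conversation_id on follow-up turns, and exposes the app's opening_statement
through get_initial_greeting(). Knowledge and tool configuration are accepted
but deferred to Dify's own orchestration. Tests cover streaming, conversation
reuse, greeting retrieval, and factory dispatch.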

Author: Xin Wang
Date:   2026-03-11 16:35:59 +08:00
Parent: 3b9ee80c8f
Commit: 5eec8f2b30

7 changed files with 455 additions and 3 deletions

@@ -28,7 +28,7 @@ from providers.tts.volcengine import VolcengineTTSService
 _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
 _DASHSCOPE_PROVIDERS = {"dashscope"}
 _VOLCENGINE_PROVIDERS = {"volcengine"}
-_SUPPORTED_LLM_PROVIDERS = {"openai", "fastgpt", *_OPENAI_COMPATIBLE_PROVIDERS}
+_SUPPORTED_LLM_PROVIDERS = {"openai", "dify", "fastgpt", *_OPENAI_COMPATIBLE_PROVIDERS}


 class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
@@ -58,6 +58,16 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
     def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort:
         provider = self._normalize_provider(spec.provider)
+        if provider == "dify" and spec.api_key and spec.base_url:
+            from providers.llm.dify import DifyLLMService
+
+            return DifyLLMService(
+                api_key=spec.api_key,
+                base_url=spec.base_url,
+                model=spec.model,
+                system_prompt=spec.system_prompt,
+            )
+
         if provider == "fastgpt" and spec.api_key and spec.base_url:
             from providers.llm.fastgpt import FastGPTLLMService

@@ -1,5 +1,6 @@
 """LLM providers."""

+from providers.llm.dify import DifyLLMService
 from providers.llm.openai import MockLLMService, OpenAILLMService

 try:  # pragma: no cover - import depends on optional sibling SDK
@@ -8,6 +9,7 @@ except Exception:  # pragma: no cover - provider remains lazily available via fa
     FastGPTLLMService = None  # type: ignore[assignment]

 __all__ = [
+    "DifyLLMService",
     "FastGPTLLMService",
     "MockLLMService",
     "OpenAILLMService",

@@ -0,0 +1,226 @@
"""Dify-backed LLM provider."""
from __future__ import annotations
import asyncio
import json
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional
import aiohttp
from loguru import logger
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
class DifyLLMService(BaseLLMService):
"""LLM provider that delegates chat orchestration to Dify Service API."""
def __init__(
self,
*,
api_key: str,
base_url: str,
model: str = "dify",
system_prompt: Optional[str] = None,
):
super().__init__(model=model or "dify")
self.api_key = api_key
self.base_url = str(base_url or "").rstrip("/")
self.system_prompt = system_prompt or ""
self._session: Optional[aiohttp.ClientSession] = None
self._cancel_event = asyncio.Event()
self._conversation_id: Optional[str] = None
self._user_id = f"engine_{uuid.uuid4().hex}"
self._knowledge_config: Dict[str, Any] = {}
self._tool_schemas: List[Dict[str, Any]] = []
async def connect(self) -> None:
if not self.api_key:
raise ValueError("Dify API key not provided")
if not self.base_url:
raise ValueError("Dify base URL not provided")
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
self._session = aiohttp.ClientSession(headers=headers)
self.state = ServiceState.CONNECTED
logger.info("Dify LLM service connected: base_url={}", self.base_url)
async def disconnect(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
self.state = ServiceState.DISCONNECTED
logger.info("Dify LLM service disconnected")
def cancel(self) -> None:
self._cancel_event.set()
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
# Dify owns retriever orchestration in this provider mode.
self._knowledge_config = dict(config or {})
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
# Dify owns tool/workflow orchestration in this provider mode.
self._tool_schemas = list(schemas or [])
async def get_initial_greeting(self) -> Optional[str]:
if self._session is None:
return None
url = f"{self.base_url}/parameters"
async with self._session.get(url, params={"user": self._user_id}) as response:
await self._raise_for_status(response, "Dify parameters request failed")
payload = await response.json()
opening_statement = str(payload.get("opening_statement") or "").strip()
return opening_statement or None
async def generate(
self,
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None,
) -> str:
parts: List[str] = []
async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens):
if event.type == "text_delta" and event.text:
parts.append(event.text)
return "".join(parts)
async def generate_stream(
self,
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None,
) -> AsyncIterator[LLMStreamEvent]:
del temperature, max_tokens
if self._session is None:
raise RuntimeError("LLM service not connected")
query = self._extract_query(messages)
if not query:
yield LLMStreamEvent(type="done")
return
if self.system_prompt:
logger.debug("Ignoring local system prompt for Dify-managed assistant config")
payload: Dict[str, Any] = {
"inputs": {},
"query": query,
"user": self._user_id,
"response_mode": "streaming",
}
if self._conversation_id:
payload["conversation_id"] = self._conversation_id
self._cancel_event.clear()
url = f"{self.base_url}/chat-messages"
response = await self._session.post(url, json=payload)
try:
await self._raise_for_status(response, "Dify chat request failed")
async for event in self._iter_sse_events(response):
if self._cancel_event.is_set():
logger.info("Dify stream cancelled")
break
event_name = str(event.get("event") or "").strip().lower()
if event.get("conversation_id"):
self._conversation_id = str(event.get("conversation_id"))
if event_name in {"message", "agent_message"}:
text = self._extract_text_delta(event)
if text:
yield LLMStreamEvent(type="text_delta", text=text)
elif event_name == "error":
raise RuntimeError(str(event.get("message") or event.get("error") or "Dify stream error"))
elif event_name in {"message_end", "agent_message_end"}:
continue
finally:
response.close()
yield LLMStreamEvent(type="done")
@staticmethod
def _extract_query(messages: List[LLMMessage]) -> str:
for message in reversed(messages):
if str(message.role or "").strip().lower() == "user":
return str(message.content or "").strip()
for message in reversed(messages):
content = str(message.content or "").strip()
if content:
return content
return ""
@staticmethod
def _extract_text_delta(event: Dict[str, Any]) -> str:
for key in ("answer", "text", "content"):
value = event.get(key)
if value is not None:
text = str(value)
if text:
return text
return ""
async def _raise_for_status(self, response: aiohttp.ClientResponse, context: str) -> None:
if int(response.status) < 400:
return
try:
payload = await response.json()
except Exception:
payload = await response.text()
raise RuntimeError(f"{context}: HTTP {response.status} {payload}")
async def _iter_sse_events(self, response: aiohttp.ClientResponse) -> AsyncIterator[Dict[str, Any]]:
event_name = ""
data_lines: List[str] = []
async for raw_line in response.content:
line = raw_line.decode("utf-8", errors="ignore").rstrip("\r\n")
if not line:
payload = self._decode_sse_payload(event_name, data_lines)
event_name = ""
data_lines = []
if payload is not None:
yield payload
continue
if line.startswith(":"):
continue
if line.startswith("event:"):
event_name = line.split(":", 1)[1].strip()
continue
if line.startswith("data:"):
data_lines.append(line.split(":", 1)[1].lstrip())
payload = self._decode_sse_payload(event_name, data_lines)
if payload is not None:
yield payload
@staticmethod
def _decode_sse_payload(event_name: str, data_lines: List[str]) -> Optional[Dict[str, Any]]:
if not data_lines:
return None
raw = "\n".join(data_lines).strip()
if not raw:
return None
if raw == "[DONE]":
return {"event": "message_end"}
try:
payload = json.loads(raw)
except json.JSONDecodeError:
logger.debug("Skipping non-JSON Dify SSE payload: {}", raw)
return None
if not isinstance(payload, dict):
return None
if event_name and not payload.get("event"):
payload["event"] = event_name
return payload
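
For orientation, here is a minimal usage sketch of the service above. The API key, base URL, and the asyncio.run() wiring are placeholder assumptions about the host runtime, not part of this commit:

import asyncio

from providers.common.base import LLMMessage
from providers.llm.dify import DifyLLMService


async def main() -> None:
    # Hypothetical credentials; base_url should point at a Dify Service API root.
    service = DifyLLMService(api_key="app-xxx", base_url="https://dify.example/v1")
    await service.connect()
    try:
        # Optional: surface the Dify app's opening_statement, if one is configured.
        greeting = await service.get_initial_greeting()
        if greeting:
            print(greeting)
        # Stream text deltas for a single user turn.
        async for event in service.generate_stream([LLMMessage(role="user", content="Hello")]):
            if event.type == "text_delta" and event.text:
                print(event.text, end="", flush=True)
    finally:
        await service.disconnect()


asyncio.run(main())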

@@ -0,0 +1,166 @@
from typing import Any, Dict, List, Optional

import pytest

from providers.common.base import LLMMessage
from providers.llm.dify import DifyLLMService


class _FakeStreamResponse:
    def __init__(self, lines: List[bytes], status: int = 200):
        self.content = _FakeStreamContent(lines)
        self.status = status
        self.closed = False

    async def json(self) -> Dict[str, Any]:
        return {}

    async def text(self) -> str:
        return ""

    def close(self) -> None:
        self.closed = True


class _FakeJSONResponse:
    def __init__(self, payload: Dict[str, Any], status: int = 200):
        self.payload = payload
        self.status = status

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False

    async def json(self) -> Dict[str, Any]:
        return dict(self.payload)

    async def text(self) -> str:
        return ""


class _FakeStreamContent:
    def __init__(self, lines: List[bytes]):
        self._lines = list(lines)

    def __aiter__(self):
        return self._iter()

    async def _iter(self):
        for line in self._lines:
            yield line


class _FakeClientSession:
    post_responses: List[_FakeStreamResponse] = []
    get_payloads: List[Dict[str, Any]] = []
    last_post_url: Optional[str] = None
    last_post_json: Optional[Dict[str, Any]] = None
    last_get_url: Optional[str] = None
    last_get_params: Optional[Dict[str, Any]] = None

    def __init__(self, headers: Optional[Dict[str, str]] = None):
        self.headers = headers or {}
        self.closed = False

    async def close(self) -> None:
        self.closed = True

    async def post(self, url: str, json: Dict[str, Any]):
        type(self).last_post_url = url
        type(self).last_post_json = dict(json)
        if not type(self).post_responses:
            raise AssertionError("No fake Dify stream response queued")
        return type(self).post_responses.pop(0)

    def get(self, url: str, params: Dict[str, Any]):
        type(self).last_get_url = url
        type(self).last_get_params = dict(params)
        if not type(self).get_payloads:
            raise AssertionError("No fake Dify JSON payload queued")
        return _FakeJSONResponse(type(self).get_payloads.pop(0))


@pytest.mark.asyncio
async def test_dify_provider_streams_message_answer_and_tracks_conversation(monkeypatch):
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    _FakeClientSession.post_responses = [
        _FakeStreamResponse(
            [
                b'data: {"event":"message","conversation_id":"conv-1","answer":"Hello "}\n',
                b"\n",
                b'data: {"event":"agent_message","conversation_id":"conv-1","answer":"from Dify."}\n',
                b"\n",
                b'data: {"event":"message_end","conversation_id":"conv-1"}\n',
                b"\n",
            ]
        )
    ]

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Hi there")])]

    assert [event.type for event in events] == ["text_delta", "text_delta", "done"]
    assert events[0].text == "Hello "
    assert events[1].text == "from Dify."
    assert service._conversation_id == "conv-1"
    assert _FakeClientSession.last_post_url == "https://dify.example/v1/chat-messages"
    assert _FakeClientSession.last_post_json == {
        "inputs": {},
        "query": "Hi there",
        "user": service._user_id,
        "response_mode": "streaming",
    }


@pytest.mark.asyncio
async def test_dify_provider_reuses_conversation_id_on_follow_up(monkeypatch):
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    _FakeClientSession.post_responses = [
        _FakeStreamResponse(
            [
                b'data: {"event":"message","conversation_id":"conv-2","answer":"First"}\n',
                b"\n",
            ]
        ),
        _FakeStreamResponse(
            [
                b'data: {"event":"message","conversation_id":"conv-2","answer":"Second"}\n',
                b"\n",
            ]
        ),
    ]

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    _ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn one")])]
    _ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn two")])]

    assert _FakeClientSession.last_post_json == {
        "inputs": {},
        "query": "Turn two",
        "user": service._user_id,
        "response_mode": "streaming",
        "conversation_id": "conv-2",
    }


@pytest.mark.asyncio
async def test_dify_provider_loads_initial_greeting_from_parameters(monkeypatch):
    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
    _FakeClientSession.get_payloads = [
        {"opening_statement": "Hello from Dify."},
    ]

    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
    await service.connect()

    greeting = await service.get_initial_greeting()

    assert greeting == "Hello from Dify."
    assert _FakeClientSession.last_get_url == "https://dify.example/v1/parameters"
    assert _FakeClientSession.last_get_params == {"user": service._user_id}

@@ -0,0 +1,32 @@
from providers.factory.default import DefaultRealtimeServiceFactory
from providers.llm.dify import DifyLLMService
from providers.llm.openai import OpenAILLMService
from runtime.ports import LLMServiceSpec


def test_create_llm_service_dify_returns_dify_provider():
    factory = DefaultRealtimeServiceFactory()
    service = factory.create_llm_service(
        LLMServiceSpec(
            provider="dify",
            model="dify",
            api_key="test-key",
            base_url="https://dify.example/v1",
        )
    )
    assert isinstance(service, DifyLLMService)


def test_create_llm_service_openai_returns_openai_provider():
    factory = DefaultRealtimeServiceFactory()
    service = factory.create_llm_service(
        LLMServiceSpec(
            provider="openai",
            model="gpt-4o-mini",
            api_key="test-key",
            base_url="https://api.openai.com/v1",
        )
    )
    assert isinstance(service, OpenAILLMService)