feat: Implement Dify LLM provider and update related configurations and tests

This commit is contained in:
Xin Wang
2026-03-11 16:35:59 +08:00
parent 3b9ee80c8f
commit 5eec8f2b30
7 changed files with 455 additions and 3 deletions

View File

@@ -0,0 +1,226 @@
"""Dify-backed LLM provider."""
from __future__ import annotations
import asyncio
import json
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional
import aiohttp
from loguru import logger
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
class DifyLLMService(BaseLLMService):
    """LLM provider that delegates chat orchestration to the Dify Service API.

    Dify owns the conversation state, knowledge retrieval, and tool/workflow
    execution in this provider mode; this class only relays the latest user
    query to ``/chat-messages`` and streams the answer back as
    ``LLMStreamEvent`` items.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str = "dify",
        system_prompt: Optional[str] = None,
    ):
        """Initialize the provider.

        Args:
            api_key: Dify Service API key used as a Bearer token.
            base_url: Base URL of the Dify Service API (trailing ``/`` is
                stripped so endpoint paths can be appended uniformly).
            model: Label forwarded to the base class; falls back to ``"dify"``.
            system_prompt: Accepted for interface compatibility but ignored at
                request time — the assistant prompt is configured inside Dify.
        """
        super().__init__(model=model or "dify")
        self.api_key = api_key
        self.base_url = str(base_url or "").rstrip("/")
        self.system_prompt = system_prompt or ""
        self._session: Optional[aiohttp.ClientSession] = None
        self._cancel_event = asyncio.Event()
        # Dify assigns the conversation id on the first exchange; we echo it
        # back on later requests so multi-turn context stays server-side.
        self._conversation_id: Optional[str] = None
        # Stable per-engine pseudo-user so Dify groups turns consistently.
        self._user_id = f"engine_{uuid.uuid4().hex}"
        self._knowledge_config: Dict[str, Any] = {}
        self._tool_schemas: List[Dict[str, Any]] = []

    async def connect(self) -> None:
        """Validate configuration and open the HTTP session.

        Raises:
            ValueError: If the API key or base URL is missing.
        """
        if not self.api_key:
            raise ValueError("Dify API key not provided")
        if not self.base_url:
            raise ValueError("Dify base URL not provided")
        # Close any stale session first so repeated connect() calls do not
        # leak aiohttp connectors.
        if self._session is not None:
            await self._session.close()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self._session = aiohttp.ClientSession(headers=headers)
        self.state = ServiceState.CONNECTED
        logger.info("Dify LLM service connected: base_url={}", self.base_url)

    async def disconnect(self) -> None:
        """Close the HTTP session (if any) and mark the service disconnected."""
        if self._session is not None:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("Dify LLM service disconnected")

    def cancel(self) -> None:
        """Request cancellation of the in-flight stream (checked per event)."""
        self._cancel_event.set()

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Store knowledge config; unused because Dify owns retrieval here."""
        # Dify owns retriever orchestration in this provider mode.
        self._knowledge_config = dict(config or {})

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Store tool schemas; unused because Dify owns tool orchestration."""
        # Dify owns tool/workflow orchestration in this provider mode.
        self._tool_schemas = list(schemas or [])

    async def get_initial_greeting(self) -> Optional[str]:
        """Fetch the app's opening statement from ``/parameters``.

        Returns:
            The configured opening statement, or ``None`` when the service is
            not connected or no statement is configured.
        """
        if self._session is None:
            return None
        url = f"{self.base_url}/parameters"
        async with self._session.get(url, params={"user": self._user_id}) as response:
            await self._raise_for_status(response, "Dify parameters request failed")
            payload = await response.json()
        opening_statement = str(payload.get("opening_statement") or "").strip()
        return opening_statement or None

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Run a full streamed generation and return the concatenated text."""
        parts: List[str] = []
        async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens):
            if event.type == "text_delta" and event.text:
                parts.append(event.text)
        return "".join(parts)

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[LLMStreamEvent]:
        """Stream the answer for the latest user query via ``/chat-messages``.

        Only the newest user message is forwarded; history, sampling settings
        (``temperature``/``max_tokens``), and the system prompt are all managed
        by the Dify app, so those arguments are intentionally ignored.

        Yields:
            ``text_delta`` events for each answer fragment, then one ``done``.

        Raises:
            RuntimeError: If the service is not connected, the HTTP request
                fails, or Dify reports a stream-level ``error`` event.
        """
        del temperature, max_tokens
        if self._session is None:
            raise RuntimeError("LLM service not connected")
        query = self._extract_query(messages)
        if not query:
            # Nothing to send; still emit the terminal event for callers.
            yield LLMStreamEvent(type="done")
            return
        if self.system_prompt:
            logger.debug("Ignoring local system prompt for Dify-managed assistant config")
        payload: Dict[str, Any] = {
            "inputs": {},
            "query": query,
            "user": self._user_id,
            "response_mode": "streaming",
        }
        if self._conversation_id:
            payload["conversation_id"] = self._conversation_id
        self._cancel_event.clear()
        url = f"{self.base_url}/chat-messages"
        response = await self._session.post(url, json=payload)
        try:
            await self._raise_for_status(response, "Dify chat request failed")
            async for event in self._iter_sse_events(response):
                if self._cancel_event.is_set():
                    logger.info("Dify stream cancelled")
                    break
                event_name = str(event.get("event") or "").strip().lower()
                # Capture the server-assigned conversation id from any event
                # so follow-up turns continue the same conversation.
                if event.get("conversation_id"):
                    self._conversation_id = str(event.get("conversation_id"))
                if event_name in {"message", "agent_message"}:
                    text = self._extract_text_delta(event)
                    if text:
                        yield LLMStreamEvent(type="text_delta", text=text)
                elif event_name == "error":
                    raise RuntimeError(str(event.get("message") or event.get("error") or "Dify stream error"))
                elif event_name in {"message_end", "agent_message_end"}:
                    continue
        finally:
            # Hard-close so a cancelled/aborted stream frees the connection.
            response.close()
        yield LLMStreamEvent(type="done")

    @staticmethod
    def _extract_query(messages: List[LLMMessage]) -> str:
        """Return the newest user message, else the newest non-empty content."""
        for message in reversed(messages):
            if str(message.role or "").strip().lower() == "user":
                return str(message.content or "").strip()
        for message in reversed(messages):
            content = str(message.content or "").strip()
            if content:
                return content
        return ""

    @staticmethod
    def _extract_text_delta(event: Dict[str, Any]) -> str:
        """Pull the text fragment from a Dify event, trying known keys in order."""
        for key in ("answer", "text", "content"):
            value = event.get(key)
            if value is not None:
                text = str(value)
                if text:
                    return text
        return ""

    async def _raise_for_status(self, response: aiohttp.ClientResponse, context: str) -> None:
        """Raise ``RuntimeError`` with body details for HTTP status >= 400."""
        if int(response.status) < 400:
            return
        try:
            payload = await response.json()
        except Exception:
            # Body was not JSON (e.g. an HTML error page) — fall back to text.
            payload = await response.text()
        raise RuntimeError(f"{context}: HTTP {response.status} {payload}")

    async def _iter_sse_events(self, response: aiohttp.ClientResponse) -> AsyncIterator[Dict[str, Any]]:
        """Parse the response body as Server-Sent Events, yielding dict payloads.

        Accumulates ``event:`` / ``data:`` fields until a blank line ends the
        event, then decodes it; a trailing event without a final blank line is
        flushed after the stream ends.
        """
        event_name = ""
        data_lines: List[str] = []
        async for raw_line in response.content:
            line = raw_line.decode("utf-8", errors="ignore").rstrip("\r\n")
            if not line:
                # Blank line terminates the current SSE event.
                payload = self._decode_sse_payload(event_name, data_lines)
                event_name = ""
                data_lines = []
                if payload is not None:
                    yield payload
                continue
            if line.startswith(":"):
                # SSE comment/heartbeat line — ignore.
                continue
            if line.startswith("event:"):
                event_name = line.split(":", 1)[1].strip()
                continue
            if line.startswith("data:"):
                value = line[5:]
                # SSE spec: remove at most ONE leading space after the colon;
                # any further whitespace belongs to the payload.
                if value.startswith(" "):
                    value = value[1:]
                data_lines.append(value)
        # Flush a final event that was not followed by a blank line.
        payload = self._decode_sse_payload(event_name, data_lines)
        if payload is not None:
            yield payload

    @staticmethod
    def _decode_sse_payload(event_name: str, data_lines: List[str]) -> Optional[Dict[str, Any]]:
        """Join buffered ``data:`` lines and decode them into an event dict.

        Returns ``None`` for empty or non-JSON payloads; maps the OpenAI-style
        ``[DONE]`` sentinel to a synthetic ``message_end`` event.
        """
        if not data_lines:
            return None
        raw = "\n".join(data_lines).strip()
        if not raw:
            return None
        if raw == "[DONE]":
            return {"event": "message_end"}
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            logger.debug("Skipping non-JSON Dify SSE payload: {}", raw)
            return None
        if not isinstance(payload, dict):
            return None
        # Prefer the payload's own event field; fall back to the SSE header.
        if event_name and not payload.get("event"):
            payload["event"] = event_name
        return payload