feat: Implement Dify LLM provider and update related configurations and tests
This commit is contained in:
226
engine/providers/llm/dify.py
Normal file
226
engine/providers/llm/dify.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Dify-backed LLM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||
|
||||
|
||||
class DifyLLMService(BaseLLMService):
|
||||
"""LLM provider that delegates chat orchestration to Dify Service API."""
|
||||
|
||||
def __init__(
    self,
    *,
    api_key: str,
    base_url: str,
    model: str = "dify",
    system_prompt: Optional[str] = None,
):
    """Configure a Dify-backed LLM service.

    Args:
        api_key: Dify Service API key (sent as a Bearer token).
        base_url: Root URL of the Dify Service API; a trailing slash is removed.
        model: Informational model label passed to the base class.
        system_prompt: Kept for interface parity; Dify manages the prompt itself.
    """
    super().__init__(model=model or "dify")

    # Connection settings.
    self.api_key = api_key
    self.base_url = str(base_url).rstrip("/") if base_url else ""
    self.system_prompt = system_prompt if system_prompt else ""

    # HTTP session is opened in connect() and torn down in disconnect().
    self._session: Optional[aiohttp.ClientSession] = None
    # Set by cancel() to stop an in-flight stream between SSE events.
    self._cancel_event = asyncio.Event()

    # Dify keeps conversation history server-side, keyed by these ids.
    self._conversation_id: Optional[str] = None
    self._user_id = f"engine_{uuid.uuid4().hex}"

    # Held only for interface parity; Dify orchestrates retrieval and tools.
    self._knowledge_config: Dict[str, Any] = {}
    self._tool_schemas: List[Dict[str, Any]] = []
|
||||
|
||||
async def connect(self) -> None:
    """Validate configuration and open the HTTP session used for Dify calls.

    Raises:
        ValueError: if the API key or base URL is missing.
    """
    if not self.api_key:
        raise ValueError("Dify API key not provided")
    if not self.base_url:
        raise ValueError("Dify base URL not provided")

    # Reconnecting must not leak the previous session's connector.
    if self._session is not None:
        await self._session.close()

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }
    self._session = aiohttp.ClientSession(headers=headers)
    self.state = ServiceState.CONNECTED
    logger.info("Dify LLM service connected: base_url={}", self.base_url)
|
||||
|
||||
async def disconnect(self) -> None:
    """Close the HTTP session (if any) and mark the service disconnected."""
    session = self._session
    if session is not None:
        await session.close()
        self._session = None
    self.state = ServiceState.DISCONNECTED
    logger.info("Dify LLM service disconnected")
|
||||
|
||||
def cancel(self) -> None:
    """Request cancellation of the active stream.

    generate_stream() checks this event between SSE events and stops early.
    """
    self._cancel_event.set()
|
||||
|
||||
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
    """Store a copy of the knowledge config.

    Dify owns retriever orchestration in this provider mode, so the value is
    retained only for interface parity with other providers.
    """
    self._knowledge_config = dict(config) if config else {}
|
||||
|
||||
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
    """Store a copy of the tool schemas.

    Dify owns tool/workflow orchestration in this provider mode, so the value
    is retained only for interface parity with other providers.
    """
    self._tool_schemas = list(schemas) if schemas else []
|
||||
|
||||
async def get_initial_greeting(self) -> Optional[str]:
    """Fetch the Dify app's opening statement.

    Returns None when not connected or when the app defines no (non-blank)
    opening statement.
    """
    session = self._session
    if session is None:
        return None

    url = f"{self.base_url}/parameters"
    async with session.get(url, params={"user": self._user_id}) as response:
        await self._raise_for_status(response, "Dify parameters request failed")
        payload = await response.json()

    greeting = str(payload.get("opening_statement") or "").strip()
    return greeting if greeting else None
|
||||
|
||||
async def generate(
    self,
    messages: List[LLMMessage],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
) -> str:
    """Non-streaming wrapper: join all text deltas from generate_stream()."""
    chunks: List[str] = []
    stream = self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens)
    async for chunk in stream:
        if chunk.type == "text_delta" and chunk.text:
            chunks.append(chunk.text)
    return "".join(chunks)
|
||||
|
||||
async def generate_stream(
    self,
    messages: List[LLMMessage],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
) -> AsyncIterator[LLMStreamEvent]:
    """Stream assistant text from Dify's /chat-messages endpoint.

    Only the newest user query is sent — Dify keeps the conversation history
    server-side, keyed by conversation_id. Yields a `text_delta` event per
    answer chunk and always finishes with a single `done` event.

    Args:
        messages: Conversation so far; only the latest user turn is sent.
        temperature: Ignored — sampling is configured in the Dify app.
        max_tokens: Ignored — limits are configured in the Dify app.

    Raises:
        RuntimeError: if not connected, on an HTTP error status, or on a
            Dify stream-level `error` event.
    """
    # Generation parameters live in the Dify app config, not the request.
    del temperature, max_tokens
    if self._session is None:
        raise RuntimeError("LLM service not connected")

    query = self._extract_query(messages)
    if not query:
        # Nothing to ask; emit a bare completion so callers still terminate.
        yield LLMStreamEvent(type="done")
        return

    if self.system_prompt:
        # The assistant persona is managed inside Dify; a local prompt would conflict.
        logger.debug("Ignoring local system prompt for Dify-managed assistant config")

    payload: Dict[str, Any] = {
        "inputs": {},
        "query": query,
        "user": self._user_id,
        "response_mode": "streaming",
    }
    # Continue the server-side conversation once one has been established.
    if self._conversation_id:
        payload["conversation_id"] = self._conversation_id

    self._cancel_event.clear()
    url = f"{self.base_url}/chat-messages"
    response = await self._session.post(url, json=payload)
    try:
        await self._raise_for_status(response, "Dify chat request failed")
        async for event in self._iter_sse_events(response):
            # cancel() may be called from another task; stop between events.
            if self._cancel_event.is_set():
                logger.info("Dify stream cancelled")
                break

            event_name = str(event.get("event") or "").strip().lower()
            # Remember the conversation id from any event that carries one.
            if event.get("conversation_id"):
                self._conversation_id = str(event.get("conversation_id"))

            if event_name in {"message", "agent_message"}:
                text = self._extract_text_delta(event)
                if text:
                    yield LLMStreamEvent(type="text_delta", text=text)
            elif event_name == "error":
                raise RuntimeError(str(event.get("message") or event.get("error") or "Dify stream error"))
            elif event_name in {"message_end", "agent_message_end"}:
                # Terminal markers; the trailing `done` event is emitted below.
                continue
    finally:
        # Release the connection even on cancellation or mid-stream errors.
        response.close()

    yield LLMStreamEvent(type="done")
|
||||
|
||||
@staticmethod
def _extract_query(messages: List[LLMMessage]) -> str:
    """Pick the query to send to Dify.

    Prefers the newest user-role message (its stripped content, even if that
    is empty); otherwise falls back to the newest non-empty message of any
    role; otherwise returns "".
    """
    for msg in reversed(messages):
        role = str(msg.role or "").strip().lower()
        if role == "user":
            return str(msg.content or "").strip()
    for msg in reversed(messages):
        text = str(msg.content or "").strip()
        if text:
            return text
    return ""
|
||||
|
||||
@staticmethod
def _extract_text_delta(event: Dict[str, Any]) -> str:
    """Return the first non-empty text chunk among Dify's known delta fields.

    Checks "answer", "text", and "content" in order; a field that is present
    but stringifies to "" is skipped in favor of the next one.
    """
    for field in ("answer", "text", "content"):
        candidate = event.get(field)
        if candidate is None:
            continue
        chunk = str(candidate)
        if chunk:
            return chunk
    return ""
|
||||
|
||||
async def _raise_for_status(self, response: aiohttp.ClientResponse, context: str) -> None:
    """Raise RuntimeError carrying the response body for 4xx/5xx statuses."""
    status = int(response.status)
    if status < 400:
        return

    # Prefer Dify's structured error body; fall back to raw text.
    try:
        detail = await response.json()
    except Exception:
        detail = await response.text()
    raise RuntimeError(f"{context}: HTTP {response.status} {detail}")
|
||||
|
||||
async def _iter_sse_events(self, response: aiohttp.ClientResponse) -> AsyncIterator[Dict[str, Any]]:
    """Parse the response body as Server-Sent Events, yielding decoded dicts.

    Buffers `event:` / `data:` lines until a blank line terminates the event,
    then delegates decoding to _decode_sse_payload. A final flush handles a
    stream that ends without a trailing blank line.
    """
    event_name = ""
    data_lines: List[str] = []

    # aiohttp's StreamReader yields the body line by line when iterated.
    async for raw_line in response.content:
        line = raw_line.decode("utf-8", errors="ignore").rstrip("\r\n")
        if not line:
            # Blank line = end of one SSE event; decode and reset the buffers.
            payload = self._decode_sse_payload(event_name, data_lines)
            event_name = ""
            data_lines = []
            if payload is not None:
                yield payload
            continue

        if line.startswith(":"):
            # SSE comment / keep-alive line.
            continue
        if line.startswith("event:"):
            event_name = line.split(":", 1)[1].strip()
            continue
        if line.startswith("data:"):
            # NOTE(review): the SSE spec strips at most one leading space;
            # lstrip() is broader, but harmless for JSON payloads.
            data_lines.append(line.split(":", 1)[1].lstrip())

    # Flush a trailing event that was not terminated by a blank line.
    payload = self._decode_sse_payload(event_name, data_lines)
    if payload is not None:
        yield payload
|
||||
|
||||
@staticmethod
def _decode_sse_payload(event_name: str, data_lines: List[str]) -> Optional[Dict[str, Any]]:
    """Join buffered SSE data lines and decode them into an event dict.

    Returns None for empty buffers, blank bodies, non-JSON payloads, and
    JSON payloads that are not objects. The literal "[DONE]" sentinel maps
    to a synthetic message_end event.
    """
    if not data_lines:
        return None

    body = "\n".join(data_lines).strip()
    if not body:
        return None
    if body == "[DONE]":
        return {"event": "message_end"}

    try:
        parsed = json.loads(body)
    except json.JSONDecodeError:
        logger.debug("Skipping non-JSON Dify SSE payload: {}", body)
        return None

    if not isinstance(parsed, dict):
        return None
    # Prefer the payload's own event field; fall back to the SSE event name.
    if event_name and not parsed.get("event"):
        parsed["event"] = event_name
    return parsed
|
||||
Reference in New Issue
Block a user