"""Dify-backed LLM provider."""

from __future__ import annotations

import asyncio
import json
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional

import aiohttp
from loguru import logger

from providers.common.base import (
    BaseLLMService,
    LLMMessage,
    LLMStreamEvent,
    ServiceState,
)


class DifyLLMService(BaseLLMService):
    """LLM provider that delegates chat orchestration to Dify Service API."""

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str = "dify",
        system_prompt: Optional[str] = None,
    ):
        """Store the configuration; no network activity happens here.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: Root URL of the API; a trailing ``/`` is stripped.
            model: Model label passed to the base class (falls back to
                ``"dify"`` when empty).
            system_prompt: Kept on the instance, but not sent to Dify by
                ``generate_stream`` (it only logs that it is ignored).
        """
        super().__init__(model=model or "dify")
        self.api_key = api_key
        self.base_url = str(base_url or "").rstrip("/")
        self.system_prompt = system_prompt or ""
        self._session: Optional[aiohttp.ClientSession] = None
        # Used purely as a flag (set / clear / is_set); it is never awaited.
        self._cancel_event = asyncio.Event()
        # Filled in from stream events and echoed back on later requests.
        self._conversation_id: Optional[str] = None
        # Random per-instance identifier sent as the ``user`` field.
        self._user_id = f"engine_{uuid.uuid4().hex}"
        self._knowledge_config: Dict[str, Any] = {}
        self._tool_schemas: List[Dict[str, Any]] = []

    async def connect(self) -> None:
        """Create the HTTP session used for all subsequent requests.

        If a session from an earlier ``connect()`` call is still held, it is
        closed first so that it is not leaked when the attribute is replaced.

        Raises:
            ValueError: If the API key or the base URL is empty.
        """
        if not self.api_key:
            raise ValueError("Dify API key not provided")
        if not self.base_url:
            raise ValueError("Dify base URL not provided")

        # Validation comes first so a failed reconnect leaves the current
        # session untouched; only then is the previous session released.
        stale_session = self._session
        if stale_session is not None:
            self._session = None
            await stale_session.close()

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self._session = aiohttp.ClientSession(headers=headers)
        self.state = ServiceState.CONNECTED
        logger.info("Dify LLM service connected: base_url={}", self.base_url)

    async def disconnect(self) -> None:
        """Close the HTTP session (if any) and mark the service disconnected."""
        if self._session is not None:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("Dify LLM service disconnected")

    def cancel(self) -> None:
        """Ask the active ``generate_stream`` loop to stop.

        The flag is checked each time a new stream event has been received.
        """
        self._cancel_event.set()

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Store a shallow copy of ``config`` (``None`` becomes ``{}``)."""
        # Dify owns retriever orchestration in this provider mode.
        self._knowledge_config = dict(config or {})

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Store a shallow copy of ``schemas`` (``None`` becomes ``[]``)."""
        # Dify owns tool/workflow orchestration in this provider mode.
        self._tool_schemas = list(schemas or [])

    async def get_initial_greeting(self) -> Optional[str]:
        """Fetch the ``opening_statement`` from the ``/parameters`` endpoint.

        Returns:
            The stripped opening statement, or ``None`` when the service is
            not connected, the response body is not a JSON object, or the
            statement is missing or blank.

        Raises:
            RuntimeError: If the endpoint answers with an HTTP status >= 400.
        """
        if self._session is None:
            return None
        url = f"{self.base_url}/parameters"
        async with self._session.get(url, params={"user": self._user_id}) as response:
            await self._raise_for_status(response, "Dify parameters request failed")
            payload = await response.json()
        # A JSON list / string / null has no ``.get``; treat it as "no greeting"
        # instead of failing with AttributeError.
        if not isinstance(payload, dict):
            return None
        opening_statement = str(payload.get("opening_statement") or "").strip()
        return opening_statement or None

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Run ``generate_stream`` to completion and return the joined text.

        Only ``text_delta`` events with non-empty text contribute to the result.
        """
        parts: List[str] = []
        async for event in self.generate_stream(
            messages, temperature=temperature, max_tokens=max_tokens
        ):
            if event.type == "text_delta" and event.text:
                parts.append(event.text)
        return "".join(parts)

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[LLMStreamEvent]:
        """Stream a reply for the query extracted from ``messages``.

        Posts to ``/chat-messages`` in streaming mode and yields one
        ``text_delta`` event per non-empty text chunk, followed by a final
        ``done`` event. ``temperature`` and ``max_tokens`` are accepted but
        ignored. When no query can be extracted, only ``done`` is yielded and
        no request is made.

        Raises:
            RuntimeError: If the service is not connected, the HTTP status is
                >= 400, or the stream carries an ``error`` event.
        """
        del temperature, max_tokens
        if self._session is None:
            raise RuntimeError("LLM service not connected")

        query = self._extract_query(messages)
        if not query:
            yield LLMStreamEvent(type="done")
            return

        if self.system_prompt:
            logger.debug("Ignoring local system prompt for Dify-managed assistant config")

        payload: Dict[str, Any] = {
            "inputs": {},
            "query": query,
            "user": self._user_id,
            "response_mode": "streaming",
        }
        if self._conversation_id:
            payload["conversation_id"] = self._conversation_id

        # Reset the flag so a cancel() aimed at an earlier stream does not
        # abort this one.
        self._cancel_event.clear()
        url = f"{self.base_url}/chat-messages"
        response = await self._session.post(url, json=payload)
        try:
            await self._raise_for_status(response, "Dify chat request failed")
            async for event in self._iter_sse_events(response):
                if self._cancel_event.is_set():
                    logger.info("Dify stream cancelled")
                    break
                event_name = str(event.get("event") or "").strip().lower()
                # Remember the conversation so the next request continues it.
                if event.get("conversation_id"):
                    self._conversation_id = str(event.get("conversation_id"))
                if event_name in {"message", "agent_message"}:
                    text = self._extract_text_delta(event)
                    if text:
                        yield LLMStreamEvent(type="text_delta", text=text)
                elif event_name == "error":
                    raise RuntimeError(
                        str(event.get("message") or event.get("error") or "Dify stream error")
                    )
                elif event_name in {"message_end", "agent_message_end"}:
                    continue
        finally:
            # Runs on normal completion, cancellation, errors and when the
            # consumer abandons the generator early.
            response.close()
        yield LLMStreamEvent(type="done")

    @staticmethod
    def _extract_query(messages: List[LLMMessage]) -> str:
        """Pick the text to send as the query.

        The most recent message whose role is ``user`` wins (its stripped
        content is returned even when that is empty). When there is no user
        message at all, the most recent message with non-blank content is
        used. Returns ``""`` if nothing qualifies.
        """
        for message in reversed(messages):
            if str(message.role or "").strip().lower() == "user":
                return str(message.content or "").strip()
        for message in reversed(messages):
            content = str(message.content or "").strip()
            if content:
                return content
        return ""

    @staticmethod
    def _extract_text_delta(event: Dict[str, Any]) -> str:
        """Return the first non-empty value among ``answer``/``text``/``content``.

        Values are converted with ``str``; ``""`` is returned when none of the
        keys holds a non-empty value.
        """
        for key in ("answer", "text", "content"):
            value = event.get(key)
            if value is not None:
                text = str(value)
                if text:
                    return text
        return ""

    async def _raise_for_status(self, response: aiohttp.ClientResponse, context: str) -> None:
        """Raise ``RuntimeError`` for HTTP status >= 400, otherwise return.

        The error message contains ``context``, the status code and the
        response body (parsed as JSON when possible, raw text otherwise).
        """
        if int(response.status) < 400:
            return
        try:
            payload = await response.json()
        except Exception:
            # Best effort: any failure to parse the body as JSON falls back to
            # the plain text so the error still carries the server's reply.
            payload = await response.text()
        raise RuntimeError(f"{context}: HTTP {response.status} {payload}")

    async def _iter_sse_events(
        self, response: aiohttp.ClientResponse
    ) -> AsyncIterator[Dict[str, Any]]:
        """Parse the response body as server-sent events and yield dict payloads.

        ``data:`` lines are accumulated until a blank line ends the event;
        lines starting with ``:`` are skipped, and ``event:`` lines set the
        name handed to ``_decode_sse_payload``. Data still buffered when the
        body ends is flushed as one last event.
        """
        event_name = ""
        data_lines: List[str] = []
        async for raw_line in response.content:
            line = raw_line.decode("utf-8", errors="ignore").rstrip("\r\n")
            if not line:
                # Blank line: dispatch whatever has been collected and reset.
                payload = self._decode_sse_payload(event_name, data_lines)
                event_name = ""
                data_lines = []
                if payload is not None:
                    yield payload
                continue
            if line.startswith(":"):
                continue
            if line.startswith("event:"):
                event_name = line.split(":", 1)[1].strip()
                continue
            if line.startswith("data:"):
                data_lines.append(line.split(":", 1)[1].lstrip())
        # The body may end without a trailing blank line.
        payload = self._decode_sse_payload(event_name, data_lines)
        if payload is not None:
            yield payload

    @staticmethod
    def _decode_sse_payload(event_name: str, data_lines: List[str]) -> Optional[Dict[str, Any]]:
        """Turn the collected ``data:`` lines of one event into a dict.

        Returns ``None`` for empty data, non-JSON data, or JSON that is not an
        object. The literal ``[DONE]`` becomes ``{"event": "message_end"}``.
        ``event_name`` is copied into the payload only when the payload has no
        truthy ``event`` value of its own.
        """
        if not data_lines:
            return None
        raw = "\n".join(data_lines).strip()
        if not raw:
            return None
        if raw == "[DONE]":
            return {"event": "message_end"}
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            logger.debug("Skipping non-JSON Dify SSE payload: {}", raw)
            return None
        if not isinstance(payload, dict):
            return None
        if event_name and not payload.get("event"):
            payload["event"] = event_name
        return payload