"""Dify-backed LLM provider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
|
|
import aiohttp
|
|
from loguru import logger
|
|
|
|
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
|
|
|
|
|
class DifyLLMService(BaseLLMService):
    """LLM provider that delegates chat orchestration to Dify Service API."""

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str = "dify",
        system_prompt: Optional[str] = None,
    ):
        super().__init__(model=model or "dify")
        self.api_key = api_key
        self.base_url = str(base_url or "").rstrip("/")
        self.system_prompt = system_prompt or ""
        self._session: Optional[aiohttp.ClientSession] = None
        self._cancel_event = asyncio.Event()
        self._conversation_id: Optional[str] = None
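        # Per-instance pseudo-user id, sent as Dify's "user" parameter on each request.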
        self._user_id = f"engine_{uuid.uuid4().hex}"
        self._knowledge_config: Dict[str, Any] = {}
        self._tool_schemas: List[Dict[str, Any]] = []

    async def connect(self) -> None:
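        """Validate configuration and open an HTTP session with bearer auth."""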
        if not self.api_key:
            raise ValueError("Dify API key not provided")
        if not self.base_url:
            raise ValueError("Dify base URL not provided")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self._session = aiohttp.ClientSession(headers=headers)
        self.state = ServiceState.CONNECTED
        logger.info("Dify LLM service connected: base_url={}", self.base_url)

    async def disconnect(self) -> None:
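        """Close the HTTP session and mark the service disconnected."""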
        if self._session is not None:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("Dify LLM service disconnected")

    def cancel(self) -> None:
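        """Signal the active stream, if any, to stop at the next event boundary."""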
        self._cancel_event.set()

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        # Dify owns retriever orchestration in this provider mode.
        self._knowledge_config = dict(config or {})

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        # Dify owns tool/workflow orchestration in this provider mode.
        self._tool_schemas = list(schemas or [])

    async def get_initial_greeting(self) -> Optional[str]:
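        """Return the Dify app's opening statement, if one is configured."""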
        if self._session is None:
            return None

        url = f"{self.base_url}/parameters"
        async with self._session.get(url, params={"user": self._user_id}) as response:
            await self._raise_for_status(response, "Dify parameters request failed")
            payload = await response.json()

        opening_statement = str(payload.get("opening_statement") or "").strip()
        return opening_statement or None

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> str:
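        """Non-streaming variant: concatenate the text deltas from generate_stream."""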
        parts: List[str] = []
        async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens):
            if event.type == "text_delta" and event.text:
                parts.append(event.text)
        return "".join(parts)

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[LLMStreamEvent]:
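        """Stream text deltas from Dify's chat-messages endpoint.

        temperature and max_tokens are accepted for interface compatibility but
        ignored; generation settings live in the Dify app configuration.
        """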
        del temperature, max_tokens
        if self._session is None:
            raise RuntimeError("LLM service not connected")

        query = self._extract_query(messages)
        if not query:
            yield LLMStreamEvent(type="done")
            return

        if self.system_prompt:
            logger.debug("Ignoring local system prompt for Dify-managed assistant config")

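        # Only the latest query is sent; Dify threads prior turns server-side
        # via conversation_id rather than replaying the full message list.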
        payload: Dict[str, Any] = {
            "inputs": {},
            "query": query,
            "user": self._user_id,
            "response_mode": "streaming",
        }
        if self._conversation_id:
            payload["conversation_id"] = self._conversation_id

        self._cancel_event.clear()
        url = f"{self.base_url}/chat-messages"
        response = await self._session.post(url, json=payload)
        try:
            await self._raise_for_status(response, "Dify chat request failed")
            async for event in self._iter_sse_events(response):
                if self._cancel_event.is_set():
                    logger.info("Dify stream cancelled")
                    break

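                # Remember the conversation id from the stream so follow-up
                # turns continue the same Dify conversation.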
                event_name = str(event.get("event") or "").strip().lower()
                if event.get("conversation_id"):
                    self._conversation_id = str(event.get("conversation_id"))

                if event_name in {"message", "agent_message"}:
                    text = self._extract_text_delta(event)
                    if text:
                        yield LLMStreamEvent(type="text_delta", text=text)
                elif event_name == "error":
                    raise RuntimeError(str(event.get("message") or event.get("error") or "Dify stream error"))
                elif event_name in {"message_end", "agent_message_end"}:
                    continue
        finally:
            response.close()

        yield LLMStreamEvent(type="done")

    @staticmethod
    def _extract_query(messages: List[LLMMessage]) -> str:
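        """Return the most recent user message, else the last non-empty message."""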
        for message in reversed(messages):
            if str(message.role or "").strip().lower() == "user":
                return str(message.content or "").strip()
        for message in reversed(messages):
            content = str(message.content or "").strip()
            if content:
                return content
        return ""

    @staticmethod
    def _extract_text_delta(event: Dict[str, Any]) -> str:
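        """Extract incremental text from a stream event, trying common payload keys."""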
        for key in ("answer", "text", "content"):
            value = event.get(key)
            if value is not None:
                text = str(value)
                if text:
                    return text
        return ""

    async def _raise_for_status(self, response: aiohttp.ClientResponse, context: str) -> None:
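        """Raise a RuntimeError carrying the response body for HTTP status >= 400."""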
        if int(response.status) < 400:
            return

        try:
            payload = await response.json()
        except Exception:
            payload = await response.text()
        raise RuntimeError(f"{context}: HTTP {response.status} {payload}")

    async def _iter_sse_events(self, response: aiohttp.ClientResponse) -> AsyncIterator[Dict[str, Any]]:
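        """Parse the response body as Server-Sent Events.

        Events are separated by blank lines; `event:` names the event,
        `data:` lines accumulate the payload, and lines starting with `:`
        are comments/keep-alives.
        """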
event_name = ""
|
|
data_lines: List[str] = []
|
|
|
|
async for raw_line in response.content:
|
|
line = raw_line.decode("utf-8", errors="ignore").rstrip("\r\n")
|
|
if not line:
|
|
payload = self._decode_sse_payload(event_name, data_lines)
|
|
event_name = ""
|
|
data_lines = []
|
|
if payload is not None:
|
|
yield payload
|
|
continue
|
|
|
|
if line.startswith(":"):
|
|
continue
|
|
if line.startswith("event:"):
|
|
event_name = line.split(":", 1)[1].strip()
|
|
continue
|
|
if line.startswith("data:"):
|
|
data_lines.append(line.split(":", 1)[1].lstrip())
|
|
|
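        # Flush a final buffered event if the stream ended without a trailing blank line.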
        payload = self._decode_sse_payload(event_name, data_lines)
        if payload is not None:
            yield payload

    @staticmethod
    def _decode_sse_payload(event_name: str, data_lines: List[str]) -> Optional[Dict[str, Any]]:
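        """Join buffered data lines and decode the result into an event dict."""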
        if not data_lines:
            return None

        raw = "\n".join(data_lines).strip()
        if not raw:
            return None
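        # Normalize an OpenAI-style "[DONE]" sentinel into a message_end event.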
        if raw == "[DONE]":
            return {"event": "message_end"}

        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            logger.debug("Skipping non-JSON Dify SSE payload: {}", raw)
            return None

        if not isinstance(payload, dict):
            return None
        if event_name and not payload.get("event"):
            payload["event"] = event_name
        return payload
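
# --- Example usage (a minimal sketch) ---------------------------------------
# Assumes a Dify Service API key ("app-...") and a base URL such as
# "https://api.dify.ai/v1"; the LLMMessage(role=..., content=...) constructor
# is assumed from how messages are read above.
#
# async def _demo() -> None:
#     service = DifyLLMService(api_key="app-...", base_url="https://api.dify.ai/v1")
#     await service.connect()
#     try:
#         messages = [LLMMessage(role="user", content="Hello!")]
#         async for event in service.generate_stream(messages):
#             if event.type == "text_delta" and event.text:
#                 print(event.text, end="", flush=True)
#     finally:
#         await service.disconnect()
#
# asyncio.run(_demo())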