diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py
index c398cc0..0f93bbb 100644
--- a/api/app/routers/assistants.py
+++ b/api/app/routers/assistants.py
@@ -302,8 +302,8 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
 
     if config_mode == "dify":
         metadata["services"]["llm"] = {
-            "provider": "openai",
-            "model": "",
+            "provider": "dify",
+            "model": "dify",
             "apiKey": assistant.api_key,
             "baseUrl": assistant.api_url,
         }
diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py
index eaab5b5..8b9f918 100644
--- a/api/tests/test_assistants.py
+++ b/api/tests/test_assistants.py
@@ -438,3 +438,19 @@ class TestAssistantAPI:
         metadata = runtime_resp.json()["sessionStartMetadata"]
         assert metadata["services"]["llm"]["provider"] == "fastgpt"
         assert metadata["services"]["llm"]["appId"] == "app-fastgpt-123"
+
+    def test_dify_runtime_config_uses_dify_provider(self, client, sample_assistant_data):
+        sample_assistant_data.update({
+            "configMode": "dify",
+            "apiUrl": "https://api.dify.ai/v1",
+            "apiKey": "dify-key",
+        })
+        assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
+        assert assistant_resp.status_code == 200
+        assistant_id = assistant_resp.json()["id"]
+
+        runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
+        assert runtime_resp.status_code == 200
+        metadata = runtime_resp.json()["sessionStartMetadata"]
+        assert metadata["services"]["llm"]["provider"] == "dify"
+        assert metadata["services"]["llm"]["model"] == "dify"
diff --git a/engine/providers/factory/default.py b/engine/providers/factory/default.py
index 478d290..b2f4147 100644
--- a/engine/providers/factory/default.py
+++ b/engine/providers/factory/default.py
@@ -28,7 +28,7 @@ from providers.tts.volcengine import VolcengineTTSService
 _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
 _DASHSCOPE_PROVIDERS = {"dashscope"}
 _VOLCENGINE_PROVIDERS = {"volcengine"}
-_SUPPORTED_LLM_PROVIDERS = {"openai", "fastgpt", *_OPENAI_COMPATIBLE_PROVIDERS}
+_SUPPORTED_LLM_PROVIDERS = {"openai", "dify", "fastgpt", *_OPENAI_COMPATIBLE_PROVIDERS}
 
 
 class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
@@ -58,6 +58,16 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
     def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort:
         provider = self._normalize_provider(spec.provider)
 
+        if provider == "dify" and spec.api_key and spec.base_url:
+            from providers.llm.dify import DifyLLMService
+
+            return DifyLLMService(
+                api_key=spec.api_key,
+                base_url=spec.base_url,
+                model=spec.model,
+                system_prompt=spec.system_prompt,
+            )
+
         if provider == "fastgpt" and spec.api_key and spec.base_url:
             from providers.llm.fastgpt import FastGPTLLMService
diff --git a/engine/providers/llm/__init__.py b/engine/providers/llm/__init__.py
index 528d1e1..76a1c9d 100644
--- a/engine/providers/llm/__init__.py
+++ b/engine/providers/llm/__init__.py
@@ -1,5 +1,6 @@
 """LLM providers."""
 
+from providers.llm.dify import DifyLLMService
 from providers.llm.openai import MockLLMService, OpenAILLMService
 
 try:  # pragma: no cover - import depends on optional sibling SDK
@@ -8,6 +9,7 @@ except Exception:  # pragma: no cover - provider remains lazily available via fa
     FastGPTLLMService = None  # type: ignore[assignment]
 
 __all__ = [
+    "DifyLLMService",
     "FastGPTLLMService",
     "MockLLMService",
     "OpenAILLMService",
diff --git a/engine/providers/llm/dify.py b/engine/providers/llm/dify.py
new file mode 100644
index 0000000..91f7b74
--- /dev/null
+++ b/engine/providers/llm/dify.py
@@ -0,0 +1,226 @@
+"""Dify-backed LLM provider."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import uuid
+from typing import Any, AsyncIterator, Dict, List, Optional
+
+import aiohttp
+from loguru import logger
+
+from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
+
+
+class DifyLLMService(BaseLLMService):
+    """LLM provider that delegates chat orchestration to the Dify Service API."""
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        base_url: str,
+        model: str = "dify",
+        system_prompt: Optional[str] = None,
+    ):
+        super().__init__(model=model or "dify")
+        self.api_key = api_key
+        self.base_url = str(base_url or "").rstrip("/")
+        self.system_prompt = system_prompt or ""
+        self._session: Optional[aiohttp.ClientSession] = None
+        self._cancel_event = asyncio.Event()
+        self._conversation_id: Optional[str] = None
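+        # Dify keeps chat history server-side; we only remember the
+        # conversation_id it returns on the first turn and replay it on
+        # follow-up requests so multi-turn context is preserved.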
+        self._user_id = f"engine_{uuid.uuid4().hex}"
+        self._knowledge_config: Dict[str, Any] = {}
+        self._tool_schemas: List[Dict[str, Any]] = []
+
+    async def connect(self) -> None:
+        if not self.api_key:
+            raise ValueError("Dify API key not provided")
+        if not self.base_url:
+            raise ValueError("Dify base URL not provided")
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+        self._session = aiohttp.ClientSession(headers=headers)
+        self.state = ServiceState.CONNECTED
+        logger.info("Dify LLM service connected: base_url={}", self.base_url)
+
+    async def disconnect(self) -> None:
+        if self._session is not None:
+            await self._session.close()
+            self._session = None
+        self.state = ServiceState.DISCONNECTED
+        logger.info("Dify LLM service disconnected")
+
+    def cancel(self) -> None:
+        self._cancel_event.set()
+
+    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
+        # Dify owns retriever orchestration in this provider mode.
+        self._knowledge_config = dict(config or {})
+
+    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
+        # Dify owns tool/workflow orchestration in this provider mode.
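+
+    async def get_initial_greeting(self) -> Optional[str]: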
+        self._tool_schemas = list(schemas or [])
+
+    async def get_initial_greeting(self) -> Optional[str]:
+        if self._session is None:
+            return None
+
+        url = f"{self.base_url}/parameters"
+        async with self._session.get(url, params={"user": self._user_id}) as response:
+            await self._raise_for_status(response, "Dify parameters request failed")
+            payload = await response.json()
+
+        opening_statement = str(payload.get("opening_statement") or "").strip()
+        return opening_statement or None
+
+    async def generate(
+        self,
+        messages: List[LLMMessage],
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+    ) -> str:
+        parts: List[str] = []
+        async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens):
+            if event.type == "text_delta" and event.text:
+                parts.append(event.text)
+        return "".join(parts)
+
+    async def generate_stream(
+        self,
+        messages: List[LLMMessage],
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+    ) -> AsyncIterator[LLMStreamEvent]:
+        del temperature, max_tokens
+        if self._session is None:
+            raise RuntimeError("LLM service not connected")
+
+        query = self._extract_query(messages)
+        if not query:
+            yield LLMStreamEvent(type="done")
+            return
+
+        if self.system_prompt:
+            logger.debug("Ignoring local system prompt for Dify-managed assistant config")
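+
+        # The Service API expects a body like:
+        #   {"inputs": {}, "query": "Hi there", "user": "engine_<hex>",
+        #    "response_mode": "streaming"}
+        # plus "conversation_id" on follow-up turns; only the fields this
+        # provider relies on are set below.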
+        payload: Dict[str, Any] = {
+            "inputs": {},
+            "query": query,
+            "user": self._user_id,
+            "response_mode": "streaming",
+        }
+        if self._conversation_id:
+            payload["conversation_id"] = self._conversation_id
+
+        self._cancel_event.clear()
+        url = f"{self.base_url}/chat-messages"
+        response = await self._session.post(url, json=payload)
+        try:
+            await self._raise_for_status(response, "Dify chat request failed")
+            async for event in self._iter_sse_events(response):
+                if self._cancel_event.is_set():
+                    logger.info("Dify stream cancelled")
+                    break
+
+                event_name = str(event.get("event") or "").strip().lower()
+                if event.get("conversation_id"):
+                    self._conversation_id = str(event.get("conversation_id"))
+
+                if event_name in {"message", "agent_message"}:
+                    text = self._extract_text_delta(event)
+                    if text:
+                        yield LLMStreamEvent(type="text_delta", text=text)
+                elif event_name == "error":
+                    raise RuntimeError(str(event.get("message") or event.get("error") or "Dify stream error"))
+                elif event_name in {"message_end", "agent_message_end"}:
+                    continue
+        finally:
+            response.close()
+
+        yield LLMStreamEvent(type="done")
+
+    @staticmethod
+    def _extract_query(messages: List[LLMMessage]) -> str:
+        for message in reversed(messages):
+            if str(message.role or "").strip().lower() == "user":
+                return str(message.content or "").strip()
+        for message in reversed(messages):
+            content = str(message.content or "").strip()
+            if content:
+                return content
+        return ""
+
+    @staticmethod
+    def _extract_text_delta(event: Dict[str, Any]) -> str:
+        for key in ("answer", "text", "content"):
+            value = event.get(key)
+            if value is not None:
+                text = str(value)
+                if text:
+                    return text
+        return ""
+
+    async def _raise_for_status(self, response: aiohttp.ClientResponse, context: str) -> None:
+        if int(response.status) < 400:
+            return
+
+        try:
+            payload = await response.json()
+        except Exception:
+            payload = await response.text()
+        raise RuntimeError(f"{context}: HTTP {response.status} {payload}")
+
+    async def _iter_sse_events(self, response: aiohttp.ClientResponse) -> AsyncIterator[Dict[str, Any]]:
+        event_name = ""
+        data_lines: List[str] = []
+
+        async for raw_line in response.content:
+            line = raw_line.decode("utf-8", errors="ignore").rstrip("\r\n")
+            if not line:
+                payload = self._decode_sse_payload(event_name, data_lines)
+                event_name = ""
+                data_lines = []
+                if payload is not None:
+                    yield payload
+                continue
+
+            if line.startswith(":"):
+                continue
+            if line.startswith("event:"):
+                event_name = line.split(":", 1)[1].strip()
+                continue
+            if line.startswith("data:"):
+                data_lines.append(line.split(":", 1)[1].lstrip())
+
+        payload = self._decode_sse_payload(event_name, data_lines)
+        if payload is not None:
+            yield payload
+
+    @staticmethod
+    def _decode_sse_payload(event_name: str, data_lines: List[str]) -> Optional[Dict[str, Any]]:
+        if not data_lines:
+            return None
+
+        raw = "\n".join(data_lines).strip()
+        if not raw:
+            return None
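+        # Some OpenAI-compatible gateways end an SSE stream with a bare
+        # "[DONE]" sentinel; map it to message_end so callers finish cleanly.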
+        if raw == "[DONE]":
+            return {"event": "message_end"}
+
+        try:
+            payload = json.loads(raw)
+        except json.JSONDecodeError:
+            logger.debug("Skipping non-JSON Dify SSE payload: {}", raw)
+            return None
+
+        if not isinstance(payload, dict):
+            return None
+        if event_name and not payload.get("event"):
+            payload["event"] = event_name
+        return payload
diff --git a/engine/tests/test_dify_provider.py b/engine/tests/test_dify_provider.py
new file mode 100644
index 0000000..4698581
--- /dev/null
+++ b/engine/tests/test_dify_provider.py
@@ -0,0 +1,166 @@
+from typing import Any, Dict, List, Optional
+
+import pytest
+
+from providers.common.base import LLMMessage
+from providers.llm.dify import DifyLLMService
+
+
+class _FakeStreamResponse:
+    def __init__(self, lines: List[bytes], status: int = 200):
+        self.content = _FakeStreamContent(lines)
+        self.status = status
+        self.closed = False
+
+    async def json(self) -> Dict[str, Any]:
+        return {}
+
+    async def text(self) -> str:
+        return ""
+
+    def close(self) -> None:
+        self.closed = True
+
+
+class _FakeJSONResponse:
+    def __init__(self, payload: Dict[str, Any], status: int = 200):
+        self.payload = payload
+        self.status = status
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return False
+
+    async def json(self) -> Dict[str, Any]:
+        return dict(self.payload)
+
+    async def text(self) -> str:
+        return ""
+
+
+class _FakeStreamContent:
+    def __init__(self, lines: List[bytes]):
+        self._lines = list(lines)
+
+    def __aiter__(self):
+        return self._iter()
+
+    async def _iter(self):
+        for line in self._lines:
+            yield line
+
+
+class _FakeClientSession:
+    post_responses: List[_FakeStreamResponse] = []
+    get_payloads: List[Dict[str, Any]] = []
+    last_post_url: Optional[str] = None
+    last_post_json: Optional[Dict[str, Any]] = None
+    last_get_url: Optional[str] = None
+    last_get_params: Optional[Dict[str, Any]] = None
+
+    def __init__(self, headers: Optional[Dict[str, str]] = None):
+        self.headers = headers or {}
+        self.closed = False
+
+    async def close(self) -> None:
+        self.closed = True
+
+    async def post(self, url: str, json: Dict[str, Any]):
+        type(self).last_post_url = url
+        type(self).last_post_json = dict(json)
+        if not type(self).post_responses:
+            raise AssertionError("No fake Dify stream response queued")
+        return type(self).post_responses.pop(0)
+
+    def get(self, url: str, params: Dict[str, Any]):
+        type(self).last_get_url = url
+        type(self).last_get_params = dict(params)
+        if not type(self).get_payloads:
+            raise AssertionError("No fake Dify JSON payload queued")
+        return _FakeJSONResponse(type(self).get_payloads.pop(0))
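+
+
+# These fakes cover only the slice of aiohttp's ClientSession API that
+# DifyLLMService touches: an awaitable post() that yields a streaming
+# response, and a get() used as an async context manager.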
+
+
+@pytest.mark.asyncio
+async def test_dify_provider_streams_message_answer_and_tracks_conversation(monkeypatch):
+    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
+    _FakeClientSession.post_responses = [
+        _FakeStreamResponse(
+            [
+                b'data: {"event":"message","conversation_id":"conv-1","answer":"Hello "}\n',
+                b"\n",
+                b'data: {"event":"agent_message","conversation_id":"conv-1","answer":"from Dify."}\n',
+                b"\n",
+                b'data: {"event":"message_end","conversation_id":"conv-1"}\n',
+                b"\n",
+            ]
+        )
+    ]
+
+    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
+    await service.connect()
+
+    events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Hi there")])]
+
+    assert [event.type for event in events] == ["text_delta", "text_delta", "done"]
+    assert events[0].text == "Hello "
+    assert events[1].text == "from Dify."
+    assert service._conversation_id == "conv-1"
+    assert _FakeClientSession.last_post_url == "https://dify.example/v1/chat-messages"
+    assert _FakeClientSession.last_post_json == {
+        "inputs": {},
+        "query": "Hi there",
+        "user": service._user_id,
+        "response_mode": "streaming",
+    }
+
+
+@pytest.mark.asyncio
+async def test_dify_provider_reuses_conversation_id_on_follow_up(monkeypatch):
+    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
+    _FakeClientSession.post_responses = [
+        _FakeStreamResponse(
+            [
+                b'data: {"event":"message","conversation_id":"conv-2","answer":"First"}\n',
+                b"\n",
+            ]
+        ),
+        _FakeStreamResponse(
+            [
+                b'data: {"event":"message","conversation_id":"conv-2","answer":"Second"}\n',
+                b"\n",
+            ]
+        ),
+    ]
+
+    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
+    await service.connect()
+
+    _ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn one")])]
+    _ = [event async for event in service.generate_stream([LLMMessage(role="user", content="Turn two")])]
+
+    assert _FakeClientSession.last_post_json == {
+        "inputs": {},
+        "query": "Turn two",
+        "user": service._user_id,
+        "response_mode": "streaming",
+        "conversation_id": "conv-2",
+    }
+
+
+@pytest.mark.asyncio
+async def test_dify_provider_loads_initial_greeting_from_parameters(monkeypatch):
+    monkeypatch.setattr("providers.llm.dify.aiohttp.ClientSession", _FakeClientSession)
+    _FakeClientSession.get_payloads = [
+        {"opening_statement": "Hello from Dify."},
+    ]
+
+    service = DifyLLMService(api_key="key", base_url="https://dify.example/v1")
+    await service.connect()
+
+    greeting = await service.get_initial_greeting()
+
+    assert greeting == "Hello from Dify."
+    assert _FakeClientSession.last_get_url == "https://dify.example/v1/parameters"
+    assert _FakeClientSession.last_get_params == {"user": service._user_id}
diff --git a/engine/tests/test_llm_factory_modes.py b/engine/tests/test_llm_factory_modes.py
new file mode 100644
index 0000000..8c966ae
--- /dev/null
+++ b/engine/tests/test_llm_factory_modes.py
@@ -0,0 +1,32 @@
+from providers.factory.default import DefaultRealtimeServiceFactory
+from providers.llm.dify import DifyLLMService
+from providers.llm.openai import OpenAILLMService
+from runtime.ports import LLMServiceSpec
+
+
+def test_create_llm_service_dify_returns_dify_provider():
+    factory = DefaultRealtimeServiceFactory()
+    service = factory.create_llm_service(
+        LLMServiceSpec(
+            provider="dify",
+            model="dify",
+            api_key="test-key",
+            base_url="https://dify.example/v1",
+        )
+    )
+
+    assert isinstance(service, DifyLLMService)
+
+
+def test_create_llm_service_openai_returns_openai_provider():
+    factory = DefaultRealtimeServiceFactory()
+    service = factory.create_llm_service(
+        LLMServiceSpec(
+            provider="openai",
+            model="gpt-4o-mini",
+            api_key="test-key",
+            base_url="https://api.openai.com/v1",
+        )
+    )
+
+    assert isinstance(service, OpenAILLMService)