"""FastGPT-backed LLM provider.""" from __future__ import annotations import asyncio import json import uuid from typing import Any, AsyncIterator, Dict, List, Optional from loguru import logger from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState from providers.llm.fastgpt_types import ( FastGPTConversationState, FastGPTField, FastGPTInteractivePrompt, FastGPTOption, FastGPTPendingInteraction, ) try: from fastgpt_client import AsyncChatClient, aiter_stream_events except Exception as exc: # pragma: no cover - exercised indirectly via connect() AsyncChatClient = None # type: ignore[assignment] aiter_stream_events = None # type: ignore[assignment] _FASTGPT_IMPORT_ERROR: Optional[Exception] = exc else: # pragma: no cover - import success depends on local environment _FASTGPT_IMPORT_ERROR = None class FastGPTLLMService(BaseLLMService): """LLM provider that delegates orchestration to FastGPT.""" INTERACTIVE_TOOL_NAME = "fastgpt.interactive" INTERACTIVE_TIMEOUT_MS = 300000 def __init__( self, *, api_key: str, base_url: str, app_id: Optional[str] = None, model: str = "fastgpt", system_prompt: Optional[str] = None, ): super().__init__(model=model or "fastgpt") self.api_key = api_key self.base_url = str(base_url or "").rstrip("/") self.app_id = str(app_id or "").strip() self.system_prompt = system_prompt or "" self.client: Any = None self._cancel_event = asyncio.Event() self._state = FastGPTConversationState() self._knowledge_config: Dict[str, Any] = {} self._tool_schemas: List[Dict[str, Any]] = [] async def connect(self) -> None: if AsyncChatClient is None or aiter_stream_events is None: raise RuntimeError( "fastgpt_client package is not available. " "Install the sibling fastgpt-python-sdk package first." ) from _FASTGPT_IMPORT_ERROR if not self.api_key: raise ValueError("FastGPT API key not provided") if not self.base_url: raise ValueError("FastGPT base URL not provided") self.client = AsyncChatClient(api_key=self.api_key, base_url=self.base_url) self.state = ServiceState.CONNECTED logger.info("FastGPT LLM service connected: base_url={}", self.base_url) async def disconnect(self) -> None: if self.client and hasattr(self.client, "close"): await self.client.close() self.client = None self._state.pending_interaction = None self.state = ServiceState.DISCONNECTED logger.info("FastGPT LLM service disconnected") def cancel(self) -> None: self._cancel_event.set() self._state.pending_interaction = None def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None: # FastGPT owns KB orchestration in this provider mode. self._knowledge_config = dict(config or {}) def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None: # FastGPT owns workflow and tool orchestration in this provider mode. self._tool_schemas = list(schemas or []) def handles_client_tool(self, tool_name: str) -> bool: return str(tool_name or "").strip() == self.INTERACTIVE_TOOL_NAME async def get_initial_greeting(self) -> Optional[str]: if not self.client or not self.app_id: return None response = await self.client.get_chat_init( appId=self.app_id, chatId=self._ensure_chat_id(), ) raise_for_status = getattr(response, "raise_for_status", None) if callable(raise_for_status): raise_for_status() elif int(getattr(response, "status_code", 200) or 200) >= 400: raise RuntimeError(f"FastGPT chat init failed: HTTP {getattr(response, 'status_code', 'unknown')}") payload = response.json() if hasattr(response, "json") else {} return self._extract_initial_greeting(payload) async def generate( self, messages: List[LLMMessage], temperature: float = 0.7, max_tokens: Optional[int] = None, ) -> str: parts: List[str] = [] async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens): if event.type == "text_delta" and event.text: parts.append(event.text) if event.type == "tool_call": break return "".join(parts) async def generate_stream( self, messages: List[LLMMessage], temperature: float = 0.7, max_tokens: Optional[int] = None, ) -> AsyncIterator[LLMStreamEvent]: del temperature, max_tokens if not self.client: raise RuntimeError("LLM service not connected") self._cancel_event.clear() request_messages = self._build_request_messages(messages) response = await self.client.create_chat_completion( messages=request_messages, chatId=self._ensure_chat_id(), detail=True, stream=True, ) try: async for event in aiter_stream_events(response): if self._cancel_event.is_set(): logger.info("FastGPT stream cancelled") break stop_after_event = False for mapped in self._map_stream_event(event): if mapped.type == "tool_call": stop_after_event = True yield mapped if stop_after_event: break finally: await self._close_stream_response(response) async def resume_after_client_tool_result( self, tool_call_id: str, result: Dict[str, Any], ) -> AsyncIterator[LLMStreamEvent]: if not self.client: raise RuntimeError("LLM service not connected") pending = self._require_pending_interaction(tool_call_id) follow_up_text = self._build_resume_text(pending, result) self._state.pending_interaction = None if not follow_up_text: yield LLMStreamEvent(type="done") return self._cancel_event.clear() response = await self.client.create_chat_completion( messages=[{"role": "user", "content": follow_up_text}], chatId=pending.chat_id, detail=True, stream=True, ) try: async for event in aiter_stream_events(response): if self._cancel_event.is_set(): logger.info("FastGPT resume stream cancelled") break stop_after_event = False for mapped in self._map_stream_event(event): if mapped.type == "tool_call": stop_after_event = True yield mapped if stop_after_event: break finally: await self._close_stream_response(response) async def _close_stream_response(self, response: Any) -> None: if response is None: return # httpx async streaming responses must use `aclose()`. aclose = getattr(response, "aclose", None) if callable(aclose): await aclose() return close = getattr(response, "close", None) if callable(close): maybe_awaitable = close() if hasattr(maybe_awaitable, "__await__"): await maybe_awaitable def _ensure_chat_id(self) -> str: chat_id = str(self._state.chat_id or "").strip() if not chat_id: chat_id = f"fastgpt_{uuid.uuid4().hex}" self._state.chat_id = chat_id return chat_id def _build_request_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]: non_empty = [msg for msg in messages if str(msg.content or "").strip()] if not non_empty: return [{"role": "user", "content": ""}] latest_user = next((msg for msg in reversed(non_empty) if msg.role == "user"), None) trailing_system = non_empty[-1] if non_empty and non_empty[-1].role == "system" else None request: List[Dict[str, Any]] = [] if trailing_system and trailing_system is not latest_user: request.append({"role": "system", "content": trailing_system.content.strip()}) if latest_user and str(latest_user.content or "").strip(): request.append({"role": "user", "content": latest_user.content.strip()}) return request last_message = non_empty[-1] payload = last_message.to_dict() payload["content"] = str(payload.get("content") or "").strip() return [payload] def _extract_initial_greeting(self, payload: Any) -> Optional[str]: if not isinstance(payload, dict): return None candidates: List[Any] = [ payload.get("app"), payload.get("data"), ] for container in candidates: if not isinstance(container, dict): continue nested_app = container.get("app") if isinstance(container.get("app"), dict) else None if nested_app: text = self._welcome_text_from_app(nested_app) if text: return text text = self._welcome_text_from_app(container) if text: return text return None @staticmethod def _welcome_text_from_app(app_payload: Dict[str, Any]) -> Optional[str]: chat_config = app_payload.get("chatConfig") if isinstance(app_payload.get("chatConfig"), dict) else {} text = str( chat_config.get("welcomeText") or app_payload.get("welcomeText") or "" ).strip() return text or None def _map_stream_event(self, event: Any) -> List[LLMStreamEvent]: kind = str(getattr(event, "kind", "") or "") data = getattr(event, "data", {}) if not isinstance(data, dict): data = {} if kind in {"data", "answer", "fastAnswer"}: chunks = self._extract_text_chunks(kind, data) return [LLMStreamEvent(type="text_delta", text=chunk) for chunk in chunks if chunk] if kind == "interactive": return [self._build_interactive_tool_event(data)] if kind == "error": message = str(data.get("message") or data.get("error") or "FastGPT streaming error") raise RuntimeError(message) if kind == "done": return [LLMStreamEvent(type="done")] return [] @staticmethod def _normalize_interactive_payload(payload: Dict[str, Any]) -> Dict[str, Any]: normalized = payload wrapped = normalized.get("interactive") if isinstance(wrapped, dict): normalized = wrapped interaction_type = str(normalized.get("type") or "").strip() if interaction_type == "toolChildrenInteractive": params = normalized.get("params") if isinstance(normalized.get("params"), dict) else {} children_response = params.get("childrenResponse") if isinstance(children_response, dict): normalized = children_response return normalized def _extract_text_chunks(self, kind: str, data: Dict[str, Any]) -> List[str]: if kind in {"answer", "fastAnswer"}: text = str(data.get("text") or "") if text: return [text] choices = data.get("choices") if isinstance(data.get("choices"), list) else [] if not choices: text = str(data.get("text") or "") return [text] if text else [] first = choices[0] if isinstance(choices[0], dict) else {} delta = first.get("delta") if isinstance(first.get("delta"), dict) else {} if isinstance(delta.get("content"), str) and delta.get("content"): return [str(delta.get("content"))] message = first.get("message") if isinstance(first.get("message"), dict) else {} if isinstance(message.get("content"), str) and message.get("content"): return [str(message.get("content"))] return [] def _build_interactive_tool_event(self, payload: Dict[str, Any]) -> LLMStreamEvent: normalized_payload = self._normalize_interactive_payload(payload) prompt = self._parse_interactive_prompt(normalized_payload) call_id = f"fgi_{uuid.uuid4().hex[:12]}" pending = FastGPTPendingInteraction( tool_call_id=call_id, chat_id=self._ensure_chat_id(), prompt=prompt, timeout_ms=self.INTERACTIVE_TIMEOUT_MS, fastgpt_event=dict(normalized_payload), ) self._state.pending_interaction = pending arguments = prompt.to_ws_arguments(chat_id=pending.chat_id) tool_call = { "id": call_id, "type": "function", "executor": "client", "wait_for_response": True, "timeout_ms": pending.timeout_ms, "display_name": prompt.title or prompt.description or prompt.prompt or "FastGPT Interactive", "function": { "name": self.INTERACTIVE_TOOL_NAME, "arguments": json.dumps(arguments, ensure_ascii=False), }, } return LLMStreamEvent(type="tool_call", tool_call=tool_call) def _parse_interactive_prompt(self, payload: Dict[str, Any]) -> FastGPTInteractivePrompt: params = payload.get("params") if isinstance(payload.get("params"), dict) else {} kind = str(payload.get("type") or "userSelect").strip() or "userSelect" title = str( payload.get("title") or params.get("title") or payload.get("nodeName") or payload.get("label") or "" ).strip() description = str( payload.get("description") or payload.get("desc") or params.get("description") or params.get("desc") or "" ).strip() prompt_text = str( payload.get("opener") or params.get("opener") or payload.get("intro") or params.get("intro") or payload.get("prompt") or params.get("prompt") or payload.get("text") or params.get("text") or title or description ).strip() required = self._coerce_bool(payload.get("required"), default=True) multiple = self._coerce_bool(params.get("multiple") or payload.get("multiple"), default=False) submit_label = str(params.get("submitText") or payload.get("submitText") or "Continue").strip() or "Continue" cancel_label = str(params.get("cancelText") or payload.get("cancelText") or "Cancel").strip() or "Cancel" options: List[FastGPTOption] = [] raw_options = params.get("userSelectOptions") if isinstance(params.get("userSelectOptions"), list) else [] for index, raw_option in enumerate(raw_options): if isinstance(raw_option, str): value = raw_option.strip() if not value: continue options.append(FastGPTOption(id=f"option_{index}", label=value, value=value)) continue if not isinstance(raw_option, dict): continue label = str(raw_option.get("label") or raw_option.get("value") or raw_option.get("id") or "").strip() value = str(raw_option.get("value") or raw_option.get("label") or raw_option.get("id") or "").strip() option_id = str(raw_option.get("id") or value or f"option_{index}").strip() if not label and not value: continue options.append( FastGPTOption( id=option_id or f"option_{index}", label=label or value, value=value or label, description=str( raw_option.get("description") or raw_option.get("desc") or raw_option.get("intro") or raw_option.get("summary") or "" ).strip(), ) ) form: List[FastGPTField] = [] raw_form = params.get("inputForm") if isinstance(params.get("inputForm"), list) else [] for index, raw_field in enumerate(raw_form): if not isinstance(raw_field, dict): continue field_options: List[FastGPTOption] = [] nested_options = raw_field.get("options") if isinstance(raw_field.get("options"), list) else [] for opt_index, option in enumerate(nested_options): if isinstance(option, str): value = option.strip() if not value: continue field_options.append(FastGPTOption(id=f"field_{index}_opt_{opt_index}", label=value, value=value)) continue if not isinstance(option, dict): continue label = str(option.get("label") or option.get("value") or option.get("id") or "").strip() value = str(option.get("value") or option.get("label") or option.get("id") or "").strip() option_id = str(option.get("id") or value or f"field_{index}_opt_{opt_index}").strip() if not label and not value: continue field_options.append( FastGPTOption( id=option_id or f"field_{index}_opt_{opt_index}", label=label or value, value=value or label, description=str( option.get("description") or option.get("desc") or option.get("intro") or option.get("summary") or "" ).strip(), ) ) name = str(raw_field.get("key") or raw_field.get("name") or raw_field.get("label") or f"field_{index}").strip() label = str(raw_field.get("label") or raw_field.get("name") or name).strip() form.append( FastGPTField( name=name or f"field_{index}", label=label or name or f"field_{index}", input_type=str(raw_field.get("type") or raw_field.get("inputType") or "text").strip() or "text", required=self._coerce_bool(raw_field.get("required"), default=False), placeholder=str( raw_field.get("placeholder") or raw_field.get("description") or raw_field.get("desc") or "" ).strip(), default=raw_field.get("defaultValue", raw_field.get("default")), options=field_options, ) ) return FastGPTInteractivePrompt( kind="userInput" if kind == "userInput" else "userSelect", title=title, description=description, prompt=prompt_text, required=required, multiple=multiple, submit_label=submit_label, cancel_label=cancel_label, options=options, form=form, raw=dict(payload), ) def _require_pending_interaction(self, tool_call_id: str) -> FastGPTPendingInteraction: pending = self._state.pending_interaction if pending is None or pending.tool_call_id != tool_call_id: raise ValueError(f"FastGPT interaction not pending for tool call: {tool_call_id}") return pending def _build_resume_text(self, pending: FastGPTPendingInteraction, result: Dict[str, Any]) -> str: status = result.get("status") if isinstance(result.get("status"), dict) else {} status_code = self._safe_int(status.get("code"), default=0) output = result.get("output") if isinstance(result.get("output"), dict) else {} action = str(output.get("action") or "").strip().lower() if action == "cancel" or status_code == 499: return "" if status_code == 422: raise ValueError("Invalid FastGPT interactive payload from client") if status_code and not 200 <= status_code < 300: raise ValueError(f"FastGPT interactive result rejected with status {status_code}") if action and action != "submit": raise ValueError(f"Unsupported FastGPT interactive action: {action}") payload = output.get("result") if isinstance(output.get("result"), dict) else output if not isinstance(payload, dict): raise ValueError("FastGPT interactive client result must be an object") if pending.prompt.kind == "userSelect": selected = str(payload.get("selected") or "").strip() if selected: return selected selected_values = payload.get("selected_values") if isinstance(payload.get("selected_values"), list) else [] values = [str(item).strip() for item in selected_values if str(item).strip()] if values: return ", ".join(values) text_value = str(payload.get("text") or "").strip() return text_value text_value = str(payload.get("text") or "").strip() if text_value: return text_value fields = payload.get("fields") if isinstance(payload.get("fields"), dict) else {} compact_fields = {str(key): value for key, value in fields.items()} if compact_fields: return json.dumps(compact_fields, ensure_ascii=False) return "" @staticmethod def _coerce_bool(value: Any, *, default: bool) -> bool: if isinstance(value, bool): return value if isinstance(value, str): normalized = value.strip().lower() if normalized in {"true", "1", "yes", "on"}: return True if normalized in {"false", "0", "no", "off"}: return False return default @staticmethod def _safe_int(value: Any, *, default: int) -> int: try: return int(value) except (TypeError, ValueError): return default