Files
2026-03-11 08:37:34 +08:00

554 lines
22 KiB
Python

"""FastGPT-backed LLM provider."""
from __future__ import annotations
import asyncio
import json
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional
from loguru import logger
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
from providers.llm.fastgpt_types import (
FastGPTConversationState,
FastGPTField,
FastGPTInteractivePrompt,
FastGPTOption,
FastGPTPendingInteraction,
)
try:
from fastgpt_client import AsyncChatClient, aiter_stream_events
except Exception as exc: # pragma: no cover - exercised indirectly via connect()
AsyncChatClient = None # type: ignore[assignment]
aiter_stream_events = None # type: ignore[assignment]
_FASTGPT_IMPORT_ERROR: Optional[Exception] = exc
else: # pragma: no cover - import success depends on local environment
_FASTGPT_IMPORT_ERROR = None
class FastGPTLLMService(BaseLLMService):
    """LLM provider that delegates orchestration to FastGPT.

    Knowledge-base retrieval, workflows and tool execution all run inside
    FastGPT; this provider forwards the latest turn and maps FastGPT's SSE
    stream (including interactive prompts) onto the base stream-event model.
    """
    # Name of the client-side pseudo-tool used to surface FastGPT
    # interactive prompts (userSelect / userInput) to the client.
    INTERACTIVE_TOOL_NAME = "fastgpt.interactive"
    # Milliseconds a client gets to answer an interactive prompt (5 minutes).
    INTERACTIVE_TIMEOUT_MS = 300000
def __init__(
    self,
    *,
    api_key: str,
    base_url: str,
    app_id: Optional[str] = None,
    model: str = "fastgpt",
    system_prompt: Optional[str] = None,
):
    """Initialize the provider; performs no network I/O (see connect()).

    Args:
        api_key: FastGPT API key passed to the async client.
        base_url: FastGPT endpoint; trailing slashes are stripped.
        app_id: Optional FastGPT app id (required for the init greeting).
        model: Display model name; falls back to "fastgpt" when empty.
        system_prompt: Optional system prompt (stored; not sent here).
    """
    super().__init__(model=model or "fastgpt")
    self.api_key = api_key
    self.base_url = str(base_url or "").rstrip("/")
    self.app_id = str(app_id or "").strip()
    self.system_prompt = system_prompt or ""
    # Created in connect(); None while disconnected.
    self.client: Any = None
    # Cooperative cancel flag polled by the streaming loops.
    self._cancel_event = asyncio.Event()
    # Tracks the chat id and any pending interactive prompt.
    self._state = FastGPTConversationState()
    # FastGPT owns KB/tool orchestration; these are retained for reference.
    self._knowledge_config: Dict[str, Any] = {}
    self._tool_schemas: List[Dict[str, Any]] = []
async def connect(self) -> None:
    """Validate configuration and construct the async FastGPT client.

    Raises:
        RuntimeError: when the fastgpt_client SDK failed to import
            (re-raised with the original import error as the cause).
        ValueError: when api_key or base_url is missing.
    """
    # Surface the deferred module-level import failure with a clear hint.
    if AsyncChatClient is None or aiter_stream_events is None:
        raise RuntimeError(
            "fastgpt_client package is not available. "
            "Install the sibling fastgpt-python-sdk package first."
        ) from _FASTGPT_IMPORT_ERROR
    if not self.api_key:
        raise ValueError("FastGPT API key not provided")
    if not self.base_url:
        raise ValueError("FastGPT base URL not provided")
    self.client = AsyncChatClient(api_key=self.api_key, base_url=self.base_url)
    self.state = ServiceState.CONNECTED
    logger.info("FastGPT LLM service connected: base_url={}", self.base_url)
async def disconnect(self) -> None:
    """Release the FastGPT client and reset conversation state.

    The client reference, pending interaction and service state are reset
    even when ``close()`` raises, so a failed shutdown cannot leave the
    service half-connected. A synchronous ``close()`` is tolerated too,
    mirroring the awaitable handling in ``_close_stream_response``.
    """
    try:
        if self.client and hasattr(self.client, "close"):
            maybe_awaitable = self.client.close()
            # Some client implementations expose a sync close(); only
            # await when the result is actually awaitable.
            if hasattr(maybe_awaitable, "__await__"):
                await maybe_awaitable
    finally:
        self.client = None
        self._state.pending_interaction = None
        self.state = ServiceState.DISCONNECTED
        logger.info("FastGPT LLM service disconnected")
def cancel(self) -> None:
    """Request cancellation of any in-flight stream.

    Sets the cancel flag polled by the streaming loops and drops any
    pending interactive prompt so it can no longer be resumed.
    """
    self._cancel_event.set()
    self._state.pending_interaction = None
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
# FastGPT owns KB orchestration in this provider mode.
self._knowledge_config = dict(config or {})
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
# FastGPT owns workflow and tool orchestration in this provider mode.
self._tool_schemas = list(schemas or [])
def handles_client_tool(self, tool_name: str) -> bool:
return str(tool_name or "").strip() == self.INTERACTIVE_TOOL_NAME
async def get_initial_greeting(self) -> Optional[str]:
    """Fetch the FastGPT app's configured welcome text, if any.

    Returns None when not connected, when no app id is configured, or
    when the init payload carries no welcome text.

    Raises:
        RuntimeError: when the init endpoint returns HTTP >= 400 and the
            response object provides no ``raise_for_status`` itself.
    """
    if not self.client or not self.app_id:
        return None
    response = await self.client.get_chat_init(
        appId=self.app_id,
        chatId=self._ensure_chat_id(),
    )
    # Prefer the response's own error handling when available
    # (httpx-style); otherwise inspect the status code manually.
    raise_for_status = getattr(response, "raise_for_status", None)
    if callable(raise_for_status):
        raise_for_status()
    elif int(getattr(response, "status_code", 200) or 200) >= 400:
        raise RuntimeError(f"FastGPT chat init failed: HTTP {getattr(response, 'status_code', 'unknown')}")
    payload = response.json() if hasattr(response, "json") else {}
    return self._extract_initial_greeting(payload)
async def generate(
self,
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None,
) -> str:
parts: List[str] = []
async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens):
if event.type == "text_delta" and event.text:
parts.append(event.text)
if event.type == "tool_call":
break
return "".join(parts)
async def generate_stream(
    self,
    messages: List[LLMMessage],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
) -> AsyncIterator[LLMStreamEvent]:
    """Stream provider events for a new FastGPT completion request.

    Args:
        messages: Local history; reduced via _build_request_messages
            because FastGPT keeps conversation history server-side.
        temperature: Accepted for interface parity; ignored (FastGPT
            controls sampling server-side).
        max_tokens: Ignored for the same reason.

    Yields:
        LLMStreamEvent items; streaming stops after the first tool_call
        (an interactive prompt that needs client input to continue).

    Raises:
        RuntimeError: when not connected or FastGPT reports an error.
    """
    del temperature, max_tokens
    if not self.client:
        raise RuntimeError("LLM service not connected")
    self._cancel_event.clear()
    request_messages = self._build_request_messages(messages)
    response = await self.client.create_chat_completion(
        messages=request_messages,
        chatId=self._ensure_chat_id(),
        detail=True,
        stream=True,
    )
    try:
        async for mapped in self._relay_stream_events(response, "FastGPT stream cancelled"):
            yield mapped
    finally:
        # Always release the streaming HTTP response, even when the
        # consumer abandons this generator early.
        await self._close_stream_response(response)

async def resume_after_client_tool_result(
    self,
    tool_call_id: str,
    result: Dict[str, Any],
) -> AsyncIterator[LLMStreamEvent]:
    """Resume the FastGPT turn after the client answered an interactive prompt.

    Args:
        tool_call_id: Id of the pending interactive tool call.
        result: Client result envelope (status / output shape parsed by
            _build_resume_text).

    Yields:
        LLMStreamEvent items; a lone "done" event when the interaction
        was cancelled or produced no follow-up text.

    Raises:
        RuntimeError: when not connected.
        ValueError: when no matching interaction is pending or the
            client result is invalid.
    """
    if not self.client:
        raise RuntimeError("LLM service not connected")
    pending = self._require_pending_interaction(tool_call_id)
    follow_up_text = self._build_resume_text(pending, result)
    self._state.pending_interaction = None
    if not follow_up_text:
        # Cancelled interaction: close the turn without contacting FastGPT.
        yield LLMStreamEvent(type="done")
        return
    self._cancel_event.clear()
    response = await self.client.create_chat_completion(
        messages=[{"role": "user", "content": follow_up_text}],
        chatId=pending.chat_id,
        detail=True,
        stream=True,
    )
    try:
        async for mapped in self._relay_stream_events(response, "FastGPT resume stream cancelled"):
            yield mapped
    finally:
        await self._close_stream_response(response)

async def _relay_stream_events(self, response: Any, cancel_message: str) -> AsyncIterator[LLMStreamEvent]:
    """Shared mapping loop for both streaming entry points.

    Maps raw SSE events via _map_stream_event, honours the cancel flag,
    and stops after the first tool_call event. The caller owns closing
    ``response`` (kept in the callers' try/finally so cleanup timing on
    early abandonment matches the original inline loops).
    """
    async for event in aiter_stream_events(response):
        if self._cancel_event.is_set():
            logger.info(cancel_message)
            break
        stop_after_event = False
        for mapped in self._map_stream_event(event):
            if mapped.type == "tool_call":
                stop_after_event = True
            yield mapped
        if stop_after_event:
            break
async def _close_stream_response(self, response: Any) -> None:
if response is None:
return
# httpx async streaming responses must use `aclose()`.
aclose = getattr(response, "aclose", None)
if callable(aclose):
await aclose()
return
close = getattr(response, "close", None)
if callable(close):
maybe_awaitable = close()
if hasattr(maybe_awaitable, "__await__"):
await maybe_awaitable
def _ensure_chat_id(self) -> str:
chat_id = str(self._state.chat_id or "").strip()
if not chat_id:
chat_id = f"fastgpt_{uuid.uuid4().hex}"
self._state.chat_id = chat_id
return chat_id
def _build_request_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
non_empty = [msg for msg in messages if str(msg.content or "").strip()]
if not non_empty:
return [{"role": "user", "content": ""}]
latest_user = next((msg for msg in reversed(non_empty) if msg.role == "user"), None)
trailing_system = non_empty[-1] if non_empty and non_empty[-1].role == "system" else None
request: List[Dict[str, Any]] = []
if trailing_system and trailing_system is not latest_user:
request.append({"role": "system", "content": trailing_system.content.strip()})
if latest_user and str(latest_user.content or "").strip():
request.append({"role": "user", "content": latest_user.content.strip()})
return request
last_message = non_empty[-1]
payload = last_message.to_dict()
payload["content"] = str(payload.get("content") or "").strip()
return [payload]
def _extract_initial_greeting(self, payload: Any) -> Optional[str]:
if not isinstance(payload, dict):
return None
candidates: List[Any] = [
payload.get("app"),
payload.get("data"),
]
for container in candidates:
if not isinstance(container, dict):
continue
nested_app = container.get("app") if isinstance(container.get("app"), dict) else None
if nested_app:
text = self._welcome_text_from_app(nested_app)
if text:
return text
text = self._welcome_text_from_app(container)
if text:
return text
return None
@staticmethod
def _welcome_text_from_app(app_payload: Dict[str, Any]) -> Optional[str]:
chat_config = app_payload.get("chatConfig") if isinstance(app_payload.get("chatConfig"), dict) else {}
text = str(
chat_config.get("welcomeText")
or app_payload.get("welcomeText")
or ""
).strip()
return text or None
def _map_stream_event(self, event: Any) -> List[LLMStreamEvent]:
    """Translate one raw FastGPT SSE event into provider stream events.

    Returns an empty list for unrecognized event kinds.

    Raises:
        RuntimeError: when FastGPT reports an ``error`` event.
    """
    kind = str(getattr(event, "kind", "") or "")
    raw = getattr(event, "data", {})
    data: Dict[str, Any] = raw if isinstance(raw, dict) else {}
    if kind in {"data", "answer", "fastAnswer"}:
        return [
            LLMStreamEvent(type="text_delta", text=chunk)
            for chunk in self._extract_text_chunks(kind, data)
            if chunk
        ]
    if kind == "interactive":
        return [self._build_interactive_tool_event(data)]
    if kind == "error":
        raise RuntimeError(str(data.get("message") or data.get("error") or "FastGPT streaming error"))
    if kind == "done":
        return [LLMStreamEvent(type="done")]
    return []
@staticmethod
def _normalize_interactive_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
normalized = payload
wrapped = normalized.get("interactive")
if isinstance(wrapped, dict):
normalized = wrapped
interaction_type = str(normalized.get("type") or "").strip()
if interaction_type == "toolChildrenInteractive":
params = normalized.get("params") if isinstance(normalized.get("params"), dict) else {}
children_response = params.get("childrenResponse")
if isinstance(children_response, dict):
normalized = children_response
return normalized
def _extract_text_chunks(self, kind: str, data: Dict[str, Any]) -> List[str]:
if kind in {"answer", "fastAnswer"}:
text = str(data.get("text") or "")
if text:
return [text]
choices = data.get("choices") if isinstance(data.get("choices"), list) else []
if not choices:
text = str(data.get("text") or "")
return [text] if text else []
first = choices[0] if isinstance(choices[0], dict) else {}
delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
if isinstance(delta.get("content"), str) and delta.get("content"):
return [str(delta.get("content"))]
message = first.get("message") if isinstance(first.get("message"), dict) else {}
if isinstance(message.get("content"), str) and message.get("content"):
return [str(message.get("content"))]
return []
def _build_interactive_tool_event(self, payload: Dict[str, Any]) -> LLMStreamEvent:
    """Turn a FastGPT interactive SSE payload into a client tool_call event.

    Also registers the interaction as pending on the conversation state so
    the turn can be resumed later via resume_after_client_tool_result().
    """
    normalized_payload = self._normalize_interactive_payload(payload)
    prompt = self._parse_interactive_prompt(normalized_payload)
    # Synthetic id links the client's eventual answer back to this prompt.
    call_id = f"fgi_{uuid.uuid4().hex[:12]}"
    pending = FastGPTPendingInteraction(
        tool_call_id=call_id,
        chat_id=self._ensure_chat_id(),
        prompt=prompt,
        timeout_ms=self.INTERACTIVE_TIMEOUT_MS,
        fastgpt_event=dict(normalized_payload),
    )
    self._state.pending_interaction = pending
    arguments = prompt.to_ws_arguments(chat_id=pending.chat_id)
    # NOTE(review): executor="client" + wait_for_response presumably tell
    # the transport layer to run this tool on the client and block for its
    # answer -- confirm against the BaseLLMService consumer.
    tool_call = {
        "id": call_id,
        "type": "function",
        "executor": "client",
        "wait_for_response": True,
        "timeout_ms": pending.timeout_ms,
        "display_name": prompt.title or prompt.description or prompt.prompt or "FastGPT Interactive",
        "function": {
            "name": self.INTERACTIVE_TOOL_NAME,
            "arguments": json.dumps(arguments, ensure_ascii=False),
        },
    }
    return LLMStreamEvent(type="tool_call", tool_call=tool_call)
def _parse_interactive_prompt(self, payload: Dict[str, Any]) -> FastGPTInteractivePrompt:
    """Normalize a FastGPT interactive payload into a structured prompt.

    FastGPT emits either a ``userSelect`` (choice list) or ``userInput``
    (form) interaction. Display strings are resolved through chains of
    fallback keys because the payload shape varies across FastGPT
    versions. Option parsing (shared by select options and per-field
    options) lives in _parse_interactive_option.
    """
    params = payload.get("params") if isinstance(payload.get("params"), dict) else {}
    kind = str(payload.get("type") or "userSelect").strip() or "userSelect"
    title = str(
        payload.get("title")
        or params.get("title")
        or payload.get("nodeName")
        or payload.get("label")
        or ""
    ).strip()
    description = str(
        payload.get("description")
        or payload.get("desc")
        or params.get("description")
        or params.get("desc")
        or ""
    ).strip()
    prompt_text = str(
        payload.get("opener")
        or params.get("opener")
        or payload.get("intro")
        or params.get("intro")
        or payload.get("prompt")
        or params.get("prompt")
        or payload.get("text")
        or params.get("text")
        or title
        or description
    ).strip()
    required = self._coerce_bool(payload.get("required"), default=True)
    multiple = self._coerce_bool(params.get("multiple") or payload.get("multiple"), default=False)
    submit_label = str(params.get("submitText") or payload.get("submitText") or "Continue").strip() or "Continue"
    cancel_label = str(params.get("cancelText") or payload.get("cancelText") or "Cancel").strip() or "Cancel"
    # userSelect choices.
    options: List[FastGPTOption] = []
    raw_options = params.get("userSelectOptions") if isinstance(params.get("userSelectOptions"), list) else []
    for index, raw_option in enumerate(raw_options):
        parsed = self._parse_interactive_option(raw_option, f"option_{index}")
        if parsed is not None:
            options.append(parsed)
    # userInput form fields (each may carry its own option list).
    form: List[FastGPTField] = []
    raw_form = params.get("inputForm") if isinstance(params.get("inputForm"), list) else []
    for index, raw_field in enumerate(raw_form):
        if not isinstance(raw_field, dict):
            continue
        field_options: List[FastGPTOption] = []
        nested_options = raw_field.get("options") if isinstance(raw_field.get("options"), list) else []
        for opt_index, option in enumerate(nested_options):
            parsed = self._parse_interactive_option(option, f"field_{index}_opt_{opt_index}")
            if parsed is not None:
                field_options.append(parsed)
        name = str(raw_field.get("key") or raw_field.get("name") or raw_field.get("label") or f"field_{index}").strip()
        label = str(raw_field.get("label") or raw_field.get("name") or name).strip()
        form.append(
            FastGPTField(
                name=name or f"field_{index}",
                label=label or name or f"field_{index}",
                input_type=str(raw_field.get("type") or raw_field.get("inputType") or "text").strip() or "text",
                required=self._coerce_bool(raw_field.get("required"), default=False),
                placeholder=str(
                    raw_field.get("placeholder")
                    or raw_field.get("description")
                    or raw_field.get("desc")
                    or ""
                ).strip(),
                default=raw_field.get("defaultValue", raw_field.get("default")),
                options=field_options,
            )
        )
    return FastGPTInteractivePrompt(
        kind="userInput" if kind == "userInput" else "userSelect",
        title=title,
        description=description,
        prompt=prompt_text,
        required=required,
        multiple=multiple,
        submit_label=submit_label,
        cancel_label=cancel_label,
        options=options,
        form=form,
        raw=dict(payload),
    )

@staticmethod
def _parse_interactive_option(raw_option: Any, fallback_id: str) -> Optional[FastGPTOption]:
    """Normalize one raw option (plain string or dict) into a FastGPTOption.

    Returns None for empty/unusable entries. Dict options resolve label,
    value and id through mutual fallbacks so any one of them suffices.
    """
    if isinstance(raw_option, str):
        value = raw_option.strip()
        if not value:
            return None
        # String options intentionally omit the description argument,
        # matching the original construction for this shape.
        return FastGPTOption(id=fallback_id, label=value, value=value)
    if not isinstance(raw_option, dict):
        return None
    label = str(raw_option.get("label") or raw_option.get("value") or raw_option.get("id") or "").strip()
    value = str(raw_option.get("value") or raw_option.get("label") or raw_option.get("id") or "").strip()
    option_id = str(raw_option.get("id") or value or fallback_id).strip()
    if not label and not value:
        return None
    return FastGPTOption(
        id=option_id or fallback_id,
        label=label or value,
        value=value or label,
        description=str(
            raw_option.get("description")
            or raw_option.get("desc")
            or raw_option.get("intro")
            or raw_option.get("summary")
            or ""
        ).strip(),
    )
def _require_pending_interaction(self, tool_call_id: str) -> FastGPTPendingInteraction:
pending = self._state.pending_interaction
if pending is None or pending.tool_call_id != tool_call_id:
raise ValueError(f"FastGPT interaction not pending for tool call: {tool_call_id}")
return pending
def _build_resume_text(self, pending: FastGPTPendingInteraction, result: Dict[str, Any]) -> str:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = self._safe_int(status.get("code"), default=0)
output = result.get("output") if isinstance(result.get("output"), dict) else {}
action = str(output.get("action") or "").strip().lower()
if action == "cancel" or status_code == 499:
return ""
if status_code == 422:
raise ValueError("Invalid FastGPT interactive payload from client")
if status_code and not 200 <= status_code < 300:
raise ValueError(f"FastGPT interactive result rejected with status {status_code}")
if action and action != "submit":
raise ValueError(f"Unsupported FastGPT interactive action: {action}")
payload = output.get("result") if isinstance(output.get("result"), dict) else output
if not isinstance(payload, dict):
raise ValueError("FastGPT interactive client result must be an object")
if pending.prompt.kind == "userSelect":
selected = str(payload.get("selected") or "").strip()
if selected:
return selected
selected_values = payload.get("selected_values") if isinstance(payload.get("selected_values"), list) else []
values = [str(item).strip() for item in selected_values if str(item).strip()]
if values:
return ", ".join(values)
text_value = str(payload.get("text") or "").strip()
return text_value
text_value = str(payload.get("text") or "").strip()
if text_value:
return text_value
fields = payload.get("fields") if isinstance(payload.get("fields"), dict) else {}
compact_fields = {str(key): value for key, value in fields.items()}
if compact_fields:
return json.dumps(compact_fields, ensure_ascii=False)
return ""
@staticmethod
def _coerce_bool(value: Any, *, default: bool) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "1", "yes", "on"}:
return True
if normalized in {"false", "0", "no", "off"}:
return False
return default
@staticmethod
def _safe_int(value: Any, *, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default