"""FastGPT-backed LLM provider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
|
|
from loguru import logger
|
|
|
|
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
|
from providers.llm.fastgpt_types import (
|
|
FastGPTConversationState,
|
|
FastGPTField,
|
|
FastGPTInteractivePrompt,
|
|
FastGPTOption,
|
|
FastGPTPendingInteraction,
|
|
)
|
|
|
|
# Optional dependency: the FastGPT SDK lives in a sibling package and may be
# absent in some environments. The failure is captured here and surfaced
# later from connect(), so importing this module never raises.
try:
    from fastgpt_client import AsyncChatClient, aiter_stream_events
except Exception as exc:  # pragma: no cover - exercised indirectly via connect()
    AsyncChatClient = None  # type: ignore[assignment]
    aiter_stream_events = None  # type: ignore[assignment]
    # Preserved so connect() can chain the original failure as the cause.
    _FASTGPT_IMPORT_ERROR: Optional[Exception] = exc
else:  # pragma: no cover - import success depends on local environment
    _FASTGPT_IMPORT_ERROR = None
|
|
|
|
|
|
class FastGPTLLMService(BaseLLMService):
    """LLM provider that delegates orchestration to FastGPT."""

    # Name of the client-executed tool used to surface FastGPT interactive
    # nodes (userSelect / userInput) to the connected client.
    INTERACTIVE_TOOL_NAME = "fastgpt.interactive"
    # How long (in milliseconds) the client may take to answer an
    # interactive prompt before the tool call times out.
    INTERACTIVE_TIMEOUT_MS = 300000
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str,
|
|
base_url: str,
|
|
app_id: Optional[str] = None,
|
|
model: str = "fastgpt",
|
|
system_prompt: Optional[str] = None,
|
|
):
|
|
super().__init__(model=model or "fastgpt")
|
|
self.api_key = api_key
|
|
self.base_url = str(base_url or "").rstrip("/")
|
|
self.app_id = str(app_id or "").strip()
|
|
self.system_prompt = system_prompt or ""
|
|
self.client: Any = None
|
|
self._cancel_event = asyncio.Event()
|
|
self._state = FastGPTConversationState()
|
|
self._knowledge_config: Dict[str, Any] = {}
|
|
self._tool_schemas: List[Dict[str, Any]] = []
|
|
|
|
async def connect(self) -> None:
|
|
if AsyncChatClient is None or aiter_stream_events is None:
|
|
raise RuntimeError(
|
|
"fastgpt_client package is not available. "
|
|
"Install the sibling fastgpt-python-sdk package first."
|
|
) from _FASTGPT_IMPORT_ERROR
|
|
if not self.api_key:
|
|
raise ValueError("FastGPT API key not provided")
|
|
if not self.base_url:
|
|
raise ValueError("FastGPT base URL not provided")
|
|
self.client = AsyncChatClient(api_key=self.api_key, base_url=self.base_url)
|
|
self.state = ServiceState.CONNECTED
|
|
logger.info("FastGPT LLM service connected: base_url={}", self.base_url)
|
|
|
|
async def disconnect(self) -> None:
|
|
if self.client and hasattr(self.client, "close"):
|
|
await self.client.close()
|
|
self.client = None
|
|
self._state.pending_interaction = None
|
|
self.state = ServiceState.DISCONNECTED
|
|
logger.info("FastGPT LLM service disconnected")
|
|
|
|
def cancel(self) -> None:
|
|
self._cancel_event.set()
|
|
self._state.pending_interaction = None
|
|
|
|
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
|
|
# FastGPT owns KB orchestration in this provider mode.
|
|
self._knowledge_config = dict(config or {})
|
|
|
|
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
|
|
# FastGPT owns workflow and tool orchestration in this provider mode.
|
|
self._tool_schemas = list(schemas or [])
|
|
|
|
def handles_client_tool(self, tool_name: str) -> bool:
|
|
return str(tool_name or "").strip() == self.INTERACTIVE_TOOL_NAME
|
|
|
|
async def get_initial_greeting(self) -> Optional[str]:
|
|
if not self.client or not self.app_id:
|
|
return None
|
|
|
|
response = await self.client.get_chat_init(
|
|
appId=self.app_id,
|
|
chatId=self._ensure_chat_id(),
|
|
)
|
|
raise_for_status = getattr(response, "raise_for_status", None)
|
|
if callable(raise_for_status):
|
|
raise_for_status()
|
|
elif int(getattr(response, "status_code", 200) or 200) >= 400:
|
|
raise RuntimeError(f"FastGPT chat init failed: HTTP {getattr(response, 'status_code', 'unknown')}")
|
|
|
|
payload = response.json() if hasattr(response, "json") else {}
|
|
return self._extract_initial_greeting(payload)
|
|
|
|
async def generate(
|
|
self,
|
|
messages: List[LLMMessage],
|
|
temperature: float = 0.7,
|
|
max_tokens: Optional[int] = None,
|
|
) -> str:
|
|
parts: List[str] = []
|
|
async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens):
|
|
if event.type == "text_delta" and event.text:
|
|
parts.append(event.text)
|
|
if event.type == "tool_call":
|
|
break
|
|
return "".join(parts)
|
|
|
|
async def generate_stream(
|
|
self,
|
|
messages: List[LLMMessage],
|
|
temperature: float = 0.7,
|
|
max_tokens: Optional[int] = None,
|
|
) -> AsyncIterator[LLMStreamEvent]:
|
|
del temperature, max_tokens
|
|
if not self.client:
|
|
raise RuntimeError("LLM service not connected")
|
|
|
|
self._cancel_event.clear()
|
|
request_messages = self._build_request_messages(messages)
|
|
response = await self.client.create_chat_completion(
|
|
messages=request_messages,
|
|
chatId=self._ensure_chat_id(),
|
|
detail=True,
|
|
stream=True,
|
|
)
|
|
try:
|
|
async for event in aiter_stream_events(response):
|
|
if self._cancel_event.is_set():
|
|
logger.info("FastGPT stream cancelled")
|
|
break
|
|
|
|
stop_after_event = False
|
|
for mapped in self._map_stream_event(event):
|
|
if mapped.type == "tool_call":
|
|
stop_after_event = True
|
|
yield mapped
|
|
if stop_after_event:
|
|
break
|
|
finally:
|
|
await self._close_stream_response(response)
|
|
|
|
async def resume_after_client_tool_result(
|
|
self,
|
|
tool_call_id: str,
|
|
result: Dict[str, Any],
|
|
) -> AsyncIterator[LLMStreamEvent]:
|
|
if not self.client:
|
|
raise RuntimeError("LLM service not connected")
|
|
|
|
pending = self._require_pending_interaction(tool_call_id)
|
|
follow_up_text = self._build_resume_text(pending, result)
|
|
self._state.pending_interaction = None
|
|
|
|
if not follow_up_text:
|
|
yield LLMStreamEvent(type="done")
|
|
return
|
|
|
|
self._cancel_event.clear()
|
|
response = await self.client.create_chat_completion(
|
|
messages=[{"role": "user", "content": follow_up_text}],
|
|
chatId=pending.chat_id,
|
|
detail=True,
|
|
stream=True,
|
|
)
|
|
try:
|
|
async for event in aiter_stream_events(response):
|
|
if self._cancel_event.is_set():
|
|
logger.info("FastGPT resume stream cancelled")
|
|
break
|
|
|
|
stop_after_event = False
|
|
for mapped in self._map_stream_event(event):
|
|
if mapped.type == "tool_call":
|
|
stop_after_event = True
|
|
yield mapped
|
|
if stop_after_event:
|
|
break
|
|
finally:
|
|
await self._close_stream_response(response)
|
|
|
|
async def _close_stream_response(self, response: Any) -> None:
|
|
if response is None:
|
|
return
|
|
|
|
# httpx async streaming responses must use `aclose()`.
|
|
aclose = getattr(response, "aclose", None)
|
|
if callable(aclose):
|
|
await aclose()
|
|
return
|
|
|
|
close = getattr(response, "close", None)
|
|
if callable(close):
|
|
maybe_awaitable = close()
|
|
if hasattr(maybe_awaitable, "__await__"):
|
|
await maybe_awaitable
|
|
|
|
def _ensure_chat_id(self) -> str:
|
|
chat_id = str(self._state.chat_id or "").strip()
|
|
if not chat_id:
|
|
chat_id = f"fastgpt_{uuid.uuid4().hex}"
|
|
self._state.chat_id = chat_id
|
|
return chat_id
|
|
|
|
    def _build_request_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Reduce the incoming history to a minimal FastGPT request payload.

        FastGPT keeps its own chat history keyed by chatId, so only the
        newest turn is forwarded.  A trailing system message (a session
        directive) is sent together with the latest user message;
        otherwise only the most recent non-empty message is forwarded.
        """
        non_empty = [msg for msg in messages if str(msg.content or "").strip()]
        if not non_empty:
            # FastGPT still requires one user entry even for an empty turn.
            return [{"role": "user", "content": ""}]

        latest_user = next((msg for msg in reversed(non_empty) if msg.role == "user"), None)
        trailing_system = non_empty[-1] if non_empty and non_empty[-1].role == "system" else None

        request: List[Dict[str, Any]] = []
        if trailing_system and trailing_system is not latest_user:
            request.append({"role": "system", "content": trailing_system.content.strip()})
            if latest_user and str(latest_user.content or "").strip():
                request.append({"role": "user", "content": latest_user.content.strip()})
            return request

        # No trailing system directive: forward only the last message.
        last_message = non_empty[-1]
        payload = last_message.to_dict()
        payload["content"] = str(payload.get("content") or "").strip()
        return [payload]
|
|
|
|
def _extract_initial_greeting(self, payload: Any) -> Optional[str]:
|
|
if not isinstance(payload, dict):
|
|
return None
|
|
|
|
candidates: List[Any] = [
|
|
payload.get("app"),
|
|
payload.get("data"),
|
|
]
|
|
for container in candidates:
|
|
if not isinstance(container, dict):
|
|
continue
|
|
nested_app = container.get("app") if isinstance(container.get("app"), dict) else None
|
|
if nested_app:
|
|
text = self._welcome_text_from_app(nested_app)
|
|
if text:
|
|
return text
|
|
text = self._welcome_text_from_app(container)
|
|
if text:
|
|
return text
|
|
|
|
return None
|
|
|
|
@staticmethod
|
|
def _welcome_text_from_app(app_payload: Dict[str, Any]) -> Optional[str]:
|
|
chat_config = app_payload.get("chatConfig") if isinstance(app_payload.get("chatConfig"), dict) else {}
|
|
text = str(
|
|
chat_config.get("welcomeText")
|
|
or app_payload.get("welcomeText")
|
|
or ""
|
|
).strip()
|
|
return text or None
|
|
|
|
def _map_stream_event(self, event: Any) -> List[LLMStreamEvent]:
|
|
kind = str(getattr(event, "kind", "") or "")
|
|
data = getattr(event, "data", {})
|
|
if not isinstance(data, dict):
|
|
data = {}
|
|
|
|
if kind in {"data", "answer", "fastAnswer"}:
|
|
chunks = self._extract_text_chunks(kind, data)
|
|
return [LLMStreamEvent(type="text_delta", text=chunk) for chunk in chunks if chunk]
|
|
|
|
if kind == "interactive":
|
|
return [self._build_interactive_tool_event(data)]
|
|
|
|
if kind == "error":
|
|
message = str(data.get("message") or data.get("error") or "FastGPT streaming error")
|
|
raise RuntimeError(message)
|
|
|
|
if kind == "done":
|
|
return [LLMStreamEvent(type="done")]
|
|
|
|
return []
|
|
|
|
@staticmethod
|
|
def _normalize_interactive_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
normalized = payload
|
|
wrapped = normalized.get("interactive")
|
|
if isinstance(wrapped, dict):
|
|
normalized = wrapped
|
|
|
|
interaction_type = str(normalized.get("type") or "").strip()
|
|
if interaction_type == "toolChildrenInteractive":
|
|
params = normalized.get("params") if isinstance(normalized.get("params"), dict) else {}
|
|
children_response = params.get("childrenResponse")
|
|
if isinstance(children_response, dict):
|
|
normalized = children_response
|
|
|
|
return normalized
|
|
|
|
def _extract_text_chunks(self, kind: str, data: Dict[str, Any]) -> List[str]:
|
|
if kind in {"answer", "fastAnswer"}:
|
|
text = str(data.get("text") or "")
|
|
if text:
|
|
return [text]
|
|
|
|
choices = data.get("choices") if isinstance(data.get("choices"), list) else []
|
|
if not choices:
|
|
text = str(data.get("text") or "")
|
|
return [text] if text else []
|
|
|
|
first = choices[0] if isinstance(choices[0], dict) else {}
|
|
delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
|
|
if isinstance(delta.get("content"), str) and delta.get("content"):
|
|
return [str(delta.get("content"))]
|
|
message = first.get("message") if isinstance(first.get("message"), dict) else {}
|
|
if isinstance(message.get("content"), str) and message.get("content"):
|
|
return [str(message.get("content"))]
|
|
return []
|
|
|
|
def _build_interactive_tool_event(self, payload: Dict[str, Any]) -> LLMStreamEvent:
|
|
normalized_payload = self._normalize_interactive_payload(payload)
|
|
prompt = self._parse_interactive_prompt(normalized_payload)
|
|
call_id = f"fgi_{uuid.uuid4().hex[:12]}"
|
|
pending = FastGPTPendingInteraction(
|
|
tool_call_id=call_id,
|
|
chat_id=self._ensure_chat_id(),
|
|
prompt=prompt,
|
|
timeout_ms=self.INTERACTIVE_TIMEOUT_MS,
|
|
fastgpt_event=dict(normalized_payload),
|
|
)
|
|
self._state.pending_interaction = pending
|
|
arguments = prompt.to_ws_arguments(chat_id=pending.chat_id)
|
|
tool_call = {
|
|
"id": call_id,
|
|
"type": "function",
|
|
"executor": "client",
|
|
"wait_for_response": True,
|
|
"timeout_ms": pending.timeout_ms,
|
|
"display_name": prompt.title or prompt.description or prompt.prompt or "FastGPT Interactive",
|
|
"function": {
|
|
"name": self.INTERACTIVE_TOOL_NAME,
|
|
"arguments": json.dumps(arguments, ensure_ascii=False),
|
|
},
|
|
}
|
|
return LLMStreamEvent(type="tool_call", tool_call=tool_call)
|
|
|
|
def _parse_interactive_prompt(self, payload: Dict[str, Any]) -> FastGPTInteractivePrompt:
|
|
params = payload.get("params") if isinstance(payload.get("params"), dict) else {}
|
|
kind = str(payload.get("type") or "userSelect").strip() or "userSelect"
|
|
title = str(
|
|
payload.get("title")
|
|
or params.get("title")
|
|
or payload.get("nodeName")
|
|
or payload.get("label")
|
|
or ""
|
|
).strip()
|
|
description = str(
|
|
payload.get("description")
|
|
or payload.get("desc")
|
|
or params.get("description")
|
|
or params.get("desc")
|
|
or ""
|
|
).strip()
|
|
prompt_text = str(
|
|
payload.get("opener")
|
|
or params.get("opener")
|
|
or payload.get("intro")
|
|
or params.get("intro")
|
|
or payload.get("prompt")
|
|
or params.get("prompt")
|
|
or payload.get("text")
|
|
or params.get("text")
|
|
or title
|
|
or description
|
|
).strip()
|
|
required = self._coerce_bool(payload.get("required"), default=True)
|
|
multiple = self._coerce_bool(params.get("multiple") or payload.get("multiple"), default=False)
|
|
submit_label = str(params.get("submitText") or payload.get("submitText") or "Continue").strip() or "Continue"
|
|
cancel_label = str(params.get("cancelText") or payload.get("cancelText") or "Cancel").strip() or "Cancel"
|
|
|
|
options: List[FastGPTOption] = []
|
|
raw_options = params.get("userSelectOptions") if isinstance(params.get("userSelectOptions"), list) else []
|
|
for index, raw_option in enumerate(raw_options):
|
|
if isinstance(raw_option, str):
|
|
value = raw_option.strip()
|
|
if not value:
|
|
continue
|
|
options.append(FastGPTOption(id=f"option_{index}", label=value, value=value))
|
|
continue
|
|
if not isinstance(raw_option, dict):
|
|
continue
|
|
label = str(raw_option.get("label") or raw_option.get("value") or raw_option.get("id") or "").strip()
|
|
value = str(raw_option.get("value") or raw_option.get("label") or raw_option.get("id") or "").strip()
|
|
option_id = str(raw_option.get("id") or value or f"option_{index}").strip()
|
|
if not label and not value:
|
|
continue
|
|
options.append(
|
|
FastGPTOption(
|
|
id=option_id or f"option_{index}",
|
|
label=label or value,
|
|
value=value or label,
|
|
description=str(
|
|
raw_option.get("description")
|
|
or raw_option.get("desc")
|
|
or raw_option.get("intro")
|
|
or raw_option.get("summary")
|
|
or ""
|
|
).strip(),
|
|
)
|
|
)
|
|
|
|
form: List[FastGPTField] = []
|
|
raw_form = params.get("inputForm") if isinstance(params.get("inputForm"), list) else []
|
|
for index, raw_field in enumerate(raw_form):
|
|
if not isinstance(raw_field, dict):
|
|
continue
|
|
field_options: List[FastGPTOption] = []
|
|
nested_options = raw_field.get("options") if isinstance(raw_field.get("options"), list) else []
|
|
for opt_index, option in enumerate(nested_options):
|
|
if isinstance(option, str):
|
|
value = option.strip()
|
|
if not value:
|
|
continue
|
|
field_options.append(FastGPTOption(id=f"field_{index}_opt_{opt_index}", label=value, value=value))
|
|
continue
|
|
if not isinstance(option, dict):
|
|
continue
|
|
label = str(option.get("label") or option.get("value") or option.get("id") or "").strip()
|
|
value = str(option.get("value") or option.get("label") or option.get("id") or "").strip()
|
|
option_id = str(option.get("id") or value or f"field_{index}_opt_{opt_index}").strip()
|
|
if not label and not value:
|
|
continue
|
|
field_options.append(
|
|
FastGPTOption(
|
|
id=option_id or f"field_{index}_opt_{opt_index}",
|
|
label=label or value,
|
|
value=value or label,
|
|
description=str(
|
|
option.get("description")
|
|
or option.get("desc")
|
|
or option.get("intro")
|
|
or option.get("summary")
|
|
or ""
|
|
).strip(),
|
|
)
|
|
)
|
|
name = str(raw_field.get("key") or raw_field.get("name") or raw_field.get("label") or f"field_{index}").strip()
|
|
label = str(raw_field.get("label") or raw_field.get("name") or name).strip()
|
|
form.append(
|
|
FastGPTField(
|
|
name=name or f"field_{index}",
|
|
label=label or name or f"field_{index}",
|
|
input_type=str(raw_field.get("type") or raw_field.get("inputType") or "text").strip() or "text",
|
|
required=self._coerce_bool(raw_field.get("required"), default=False),
|
|
placeholder=str(
|
|
raw_field.get("placeholder")
|
|
or raw_field.get("description")
|
|
or raw_field.get("desc")
|
|
or ""
|
|
).strip(),
|
|
default=raw_field.get("defaultValue", raw_field.get("default")),
|
|
options=field_options,
|
|
)
|
|
)
|
|
|
|
return FastGPTInteractivePrompt(
|
|
kind="userInput" if kind == "userInput" else "userSelect",
|
|
title=title,
|
|
description=description,
|
|
prompt=prompt_text,
|
|
required=required,
|
|
multiple=multiple,
|
|
submit_label=submit_label,
|
|
cancel_label=cancel_label,
|
|
options=options,
|
|
form=form,
|
|
raw=dict(payload),
|
|
)
|
|
|
|
def _require_pending_interaction(self, tool_call_id: str) -> FastGPTPendingInteraction:
|
|
pending = self._state.pending_interaction
|
|
if pending is None or pending.tool_call_id != tool_call_id:
|
|
raise ValueError(f"FastGPT interaction not pending for tool call: {tool_call_id}")
|
|
return pending
|
|
|
|
def _build_resume_text(self, pending: FastGPTPendingInteraction, result: Dict[str, Any]) -> str:
|
|
status = result.get("status") if isinstance(result.get("status"), dict) else {}
|
|
status_code = self._safe_int(status.get("code"), default=0)
|
|
output = result.get("output") if isinstance(result.get("output"), dict) else {}
|
|
action = str(output.get("action") or "").strip().lower()
|
|
|
|
if action == "cancel" or status_code == 499:
|
|
return ""
|
|
if status_code == 422:
|
|
raise ValueError("Invalid FastGPT interactive payload from client")
|
|
if status_code and not 200 <= status_code < 300:
|
|
raise ValueError(f"FastGPT interactive result rejected with status {status_code}")
|
|
if action and action != "submit":
|
|
raise ValueError(f"Unsupported FastGPT interactive action: {action}")
|
|
|
|
payload = output.get("result") if isinstance(output.get("result"), dict) else output
|
|
if not isinstance(payload, dict):
|
|
raise ValueError("FastGPT interactive client result must be an object")
|
|
|
|
if pending.prompt.kind == "userSelect":
|
|
selected = str(payload.get("selected") or "").strip()
|
|
if selected:
|
|
return selected
|
|
selected_values = payload.get("selected_values") if isinstance(payload.get("selected_values"), list) else []
|
|
values = [str(item).strip() for item in selected_values if str(item).strip()]
|
|
if values:
|
|
return ", ".join(values)
|
|
text_value = str(payload.get("text") or "").strip()
|
|
return text_value
|
|
|
|
text_value = str(payload.get("text") or "").strip()
|
|
if text_value:
|
|
return text_value
|
|
fields = payload.get("fields") if isinstance(payload.get("fields"), dict) else {}
|
|
compact_fields = {str(key): value for key, value in fields.items()}
|
|
if compact_fields:
|
|
return json.dumps(compact_fields, ensure_ascii=False)
|
|
return ""
|
|
|
|
@staticmethod
|
|
def _coerce_bool(value: Any, *, default: bool) -> bool:
|
|
if isinstance(value, bool):
|
|
return value
|
|
if isinstance(value, str):
|
|
normalized = value.strip().lower()
|
|
if normalized in {"true", "1", "yes", "on"}:
|
|
return True
|
|
if normalized in {"false", "0", "no", "off"}:
|
|
return False
|
|
return default
|
|
|
|
@staticmethod
|
|
def _safe_int(value: Any, *, default: int) -> int:
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return default
|