diff --git a/src/api/endpoints.py b/src/api/endpoints.py
index 852b1d9..c290d31 100644
--- a/src/api/endpoints.py
+++ b/src/api/endpoints.py
@@ -15,6 +15,7 @@ import time
router = APIRouter()
FORM_EXTRACT_MODULE_NAME = "文本内容提取事故信息"
+STATE_TAG_PATTERN = re.compile(r"\s*(\d+)\s*", flags=re.DOTALL)
STATUS_CODE_MAP = {
'0000': '结束通话',
'0001': '转接人工',
@@ -50,27 +51,12 @@ def normalize_stage_code(stage_code: str) -> str:
return stage_code
-def extract_state_and_content(data1: str) -> dict | None:
- """
- Extracts the state and content from a string in the format STATEcontent.
-
- Args:
- data1: The input string.
-
- Returns:
- A dictionary with 'state' and 'content' keys if a match is found,
- otherwise None.
- """
- data1 = data1.strip()
- regex = r"(.*?)(.*)"
- match = re.search(regex, data1, flags=re.DOTALL)
-
- if match:
- return {
- "state": match.group(1),
- "content": match.group(2),
- }
- return None
+def extract_first_state_and_clean_content(text: str) -> tuple[str | None, str]:
+    """Return (group 1 of the first STATE_TAG_PATTERN match, text with every match removed), or (None, text) when nothing matches."""
+    first_tag = STATE_TAG_PATTERN.search(text)
+    if first_tag is not None:
+        return first_tag.group(1), STATE_TAG_PATTERN.sub("", text)
+    return None, text
def parse_json_value(value):
@@ -104,7 +90,7 @@ def extract_form_update_from_flow_nodes(nodes):
if not isinstance(extract_result, dict):
return {}
- form_update = extract_result.get("formUpdate") or extract_result.get("form") or ""
+ form_update = extract_result.get("formUpdate", "")
if not form_update:
return {}
return parse_json_value(form_update)
@@ -221,6 +207,7 @@ async def chat(
)
buffer = ""
+ state_filter_buffer = ""
state_code_found = False
module_form_sent = False
@@ -255,6 +242,41 @@ async def chat(
if not chunk:
return []
return [flush_text_delta(chunk)]
+
+ def clean_later_state_tags(text: str, *, final: bool = False) -> str:
+ nonlocal state_filter_buffer
+ state_filter_buffer += text
+ state_filter_buffer = STATE_TAG_PATTERN.sub("", state_filter_buffer)
+
+ state_start_index = state_filter_buffer.find("= 0:
+ cleaned = state_filter_buffer[:state_start_index]
+ state_filter_buffer = state_filter_buffer[state_start_index:]
+ return cleaned
+
+ if final:
+ if state_filter_buffer.startswith("XXXX pattern
- match = re.search(r"(.*?)", buffer, flags=re.DOTALL)
- if match:
- state_code = match.group(1)
+ state_code, cleaned_content = extract_first_state_and_clean_content(buffer)
+ if state_code:
# Apply logic to map/adjust state code
nextStageCode = normalize_stage_code(state_code)
@@ -319,15 +340,15 @@ async def chat(
state_code_found = True
- # Send remaining content as text_delta
- remaining_content = buffer[match.end():]
- if remaining_content:
- for text_event in build_text_delta_events(remaining_content):
+ if cleaned_content:
+ for text_event in build_text_delta_events(cleaned_content):
yield text_event
buffer = "" # Clear buffer after extracting state
else:
- for text_event in build_text_delta_events(delta_content):
- yield text_event
+ cleaned_content = clean_later_state_tags(delta_content)
+ if cleaned_content:
+ for text_event in build_text_delta_events(cleaned_content):
+ yield text_event
buffer = ""
except Exception as e:
@@ -339,6 +360,11 @@ async def chat(
if not state_code_found and buffer:
for text_event in build_text_delta_events(buffer):
yield text_event
+ elif state_code_found:
+ cleaned_content = clean_later_state_tags("", final=True)
+ if cleaned_content:
+ for text_event in build_text_delta_events(cleaned_content):
+ yield text_event
for text_event in flush_text_chunker_events():
yield text_event
@@ -470,11 +496,12 @@ async def chat(
if isinstance(content, str):
logger.debug("content是一个str")
- state_and_content = extract_state_and_content(content)
- if state_and_content:
- logger.debug(f"解析后的state和content为: {state_and_content}")
- content_stage_code = state_and_content['state']
- content = state_and_content['content']
+ content_stage_code, content = extract_first_state_and_clean_content(content)
+ if content_stage_code:
+ logger.debug(
+ f"解析后的第一个state为: {content_stage_code}, "
+ f"移除state标签后的content为: {content}"
+ )
else:
raise ValueError("大模型回复中的state解析失败")
else: