Refactor state extraction and content cleaning in chat endpoint

- Introduced a new function to extract the first state code and clean content by removing state tags.
- Updated the chat endpoint to utilize the new extraction function, improving state handling and content processing.
- Enhanced logging to provide clearer insights into extracted state codes and cleaned content.
This commit is contained in:
Xin Wang
2026-06-18 17:10:02 +08:00
parent aa2768acc0
commit 569dae4446

View File

@@ -15,6 +15,7 @@ import time
router = APIRouter()
FORM_EXTRACT_MODULE_NAME = "文本内容提取事故信息"
STATE_TAG_PATTERN = re.compile(r"<state>\s*(\d+)\s*</state>", flags=re.DOTALL)
STATUS_CODE_MAP = {
'0000': '结束通话',
'0001': '转接人工',
@@ -50,27 +51,12 @@ def normalize_stage_code(stage_code: str) -> str:
return stage_code
def extract_state_and_content(data1: str) -> dict | None:
"""
Extracts the state and content from a string in the format <state>STATE</state>content.
Args:
data1: The input string.
Returns:
A dictionary with 'state' and 'content' keys if a match is found,
otherwise None.
"""
data1 = data1.strip()
regex = r"<state>(.*?)</state>(.*)"
match = re.search(regex, data1, flags=re.DOTALL)
if match:
return {
"state": match.group(1),
"content": match.group(2),
}
return None
def extract_first_state_and_clean_content(text: str) -> tuple[str | None, str]:
    """Extract the first state code and strip every state tag from ``text``.

    Args:
        text: Text that may contain tags matched by STATE_TAG_PATTERN.

    Returns:
        A ``(state_code, content)`` pair. When no tag matches, the pair is
        ``(None, text)`` with the text untouched; otherwise ``state_code`` is
        the captured group of the first match and ``content`` is ``text``
        with all matching tags removed.
    """
    first_tag = STATE_TAG_PATTERN.search(text)
    if first_tag is None:
        # Nothing to extract: hand the text back exactly as received.
        return None, text

    state_code = first_tag.group(1)
    tag_free_text = STATE_TAG_PATTERN.sub("", text)
    return state_code, tag_free_text
def parse_json_value(value):
@@ -104,7 +90,7 @@ def extract_form_update_from_flow_nodes(nodes):
if not isinstance(extract_result, dict):
return {}
form_update = extract_result.get("formUpdate") or extract_result.get("form") or ""
form_update = extract_result.get("formUpdate", "")
if not form_update:
return {}
return parse_json_value(form_update)
@@ -221,6 +207,7 @@ async def chat(
)
buffer = ""
state_filter_buffer = ""
state_code_found = False
module_form_sent = False
@@ -255,6 +242,41 @@ async def chat(
if not chunk:
return []
return [flush_text_delta(chunk)]
def clean_later_state_tags(text: str, *, final: bool = False) -> str:
nonlocal state_filter_buffer
state_filter_buffer += text
state_filter_buffer = STATE_TAG_PATTERN.sub("", state_filter_buffer)
state_start_index = state_filter_buffer.find("<state")
if state_start_index >= 0:
cleaned = state_filter_buffer[:state_start_index]
state_filter_buffer = state_filter_buffer[state_start_index:]
return cleaned
if final:
if state_filter_buffer.startswith("<state"):
state_filter_buffer = ""
return ""
cleaned = state_filter_buffer
state_filter_buffer = ""
return cleaned
hold_length = 0
state_prefix = "<state>"
max_suffix_length = min(len(state_filter_buffer), len(state_prefix) - 1)
for suffix_length in range(1, max_suffix_length + 1):
if state_prefix.startswith(state_filter_buffer[-suffix_length:]):
hold_length = suffix_length
if hold_length:
cleaned = state_filter_buffer[:-hold_length]
state_filter_buffer = state_filter_buffer[-hold_length:]
return cleaned
cleaned = state_filter_buffer
state_filter_buffer = ""
return cleaned
async for event in aiter_stream_events(response):
try:
@@ -298,9 +320,8 @@ async def chat(
if not state_code_found:
# Check for <state>XXXX</state> pattern
match = re.search(r"<state>(.*?)</state>", buffer, flags=re.DOTALL)
if match:
state_code = match.group(1)
state_code, cleaned_content = extract_first_state_and_clean_content(buffer)
if state_code:
# Apply logic to map/adjust state code
nextStageCode = normalize_stage_code(state_code)
@@ -319,15 +340,15 @@ async def chat(
state_code_found = True
# Send remaining content as text_delta
remaining_content = buffer[match.end():]
if remaining_content:
for text_event in build_text_delta_events(remaining_content):
if cleaned_content:
for text_event in build_text_delta_events(cleaned_content):
yield text_event
buffer = "" # Clear buffer after extracting state
else:
for text_event in build_text_delta_events(delta_content):
yield text_event
cleaned_content = clean_later_state_tags(delta_content)
if cleaned_content:
for text_event in build_text_delta_events(cleaned_content):
yield text_event
buffer = ""
except Exception as e:
@@ -339,6 +360,11 @@ async def chat(
if not state_code_found and buffer:
for text_event in build_text_delta_events(buffer):
yield text_event
elif state_code_found:
cleaned_content = clean_later_state_tags("", final=True)
if cleaned_content:
for text_event in build_text_delta_events(cleaned_content):
yield text_event
for text_event in flush_text_chunker_events():
yield text_event
@@ -470,11 +496,12 @@ async def chat(
if isinstance(content, str):
logger.debug("content是一个str")
state_and_content = extract_state_and_content(content)
if state_and_content:
logger.debug(f"解析后的state和content为: {state_and_content}")
content_stage_code = state_and_content['state']
content = state_and_content['content']
content_stage_code, content = extract_first_state_and_clean_content(content)
if content_stage_code:
logger.debug(
f"解析后的第一个state为: {content_stage_code}, "
f"移除state标签后的content为: {content}"
)
else:
raise ValueError("大模型回复中的state解析失败")
else: