From a6777a827bcc820aeddbd4473ab769859cec1870 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 17 Jun 2026 14:18:24 +0800 Subject: [PATCH] Add text chunking functionality to chat endpoint - Introduced SentenceTextChunker and SentenceTextChunkerConfig for improved text processing in chat responses. - Updated chat endpoint to conditionally use text chunking based on the new 'useTextChunk' parameter from the request. - Enhanced logging to include 'useTextChunk' status and adjusted text delta handling to support chunked responses. - Modified ProcessRequest_chat model to include 'useTextChunk' field for request handling. - Added unit tests for SentenceTextChunker to ensure correct chunking behavior and edge case handling. --- src/api/endpoints.py | 43 ++++++++++-- src/schemas/models.py | 1 + src/voice/text_chunker.py | 135 ++++++++++++++++++++++++++++++++++++++ test/test_text_chunker.py | 57 ++++++++++++++++ 4 files changed, 232 insertions(+), 4 deletions(-) create mode 100644 src/voice/text_chunker.py create mode 100644 test/test_text_chunker.py diff --git a/src/api/endpoints.py b/src/api/endpoints.py index 1b766dc..28b9601 100644 --- a/src/api/endpoints.py +++ b/src/api/endpoints.py @@ -7,6 +7,7 @@ from fastgpt_client.exceptions import ( ) from ..core.fastgpt_client import get_fastgpt_client from ..core.config import Config +from ..voice.text_chunker import SentenceTextChunker, SentenceTextChunkerConfig from loguru import logger import json import re @@ -173,12 +174,14 @@ async def chat( """Handle chat completion request.""" json_data = request.model_dump() need_form_update = json_data.get('needFormUpdate', False) + use_text_chunk = json_data.get('useTextChunk', False) chat_variables = {'needFormUpdate': need_form_update} request_started_at = time.perf_counter() logger.info( "Chat request received " f"sessionId={json_data['sessionId']} stream={stream} " - f"needFormUpdate={need_form_update} text_len={len(json_data.get('text', ''))} " + f"needFormUpdate={need_form_update} useTextChunk={use_text_chunk} " + f"text_len={len(json_data.get('text', ''))} " f"input={json_data.get('text', '')!r}" ) @@ -191,6 +194,17 @@ async def chat( text_delta_count = 0 output_chunks = [] form_update_payload = {} + text_chunker = ( + SentenceTextChunker( + SentenceTextChunkerConfig( + min_chars=1, + max_chars=0, + use_soft_breaks=False, + ) + ) + if use_text_chunk + else None + ) try: # Use SDK's create_chat_completion with stream=True response = await client.create_chat_completion( @@ -227,6 +241,20 @@ async def chat( def flush_form_update(form_update): return create_sse_event("formUpdate", form_update) + + def build_text_delta_events(text: str): + if not text: + return [] + chunks = text_chunker.feed(text) if text_chunker else [text] + return [flush_text_delta(chunk) for chunk in chunks if chunk] + + def flush_text_chunker_events(): + if not text_chunker: + return [] + chunk = text_chunker.flush() + if not chunk: + return [] + return [flush_text_delta(chunk)] async for event in aiter_stream_events(response): try: @@ -294,10 +322,12 @@ async def chat( # Send remaining content as text_delta remaining_content = buffer[match.end():] if remaining_content: - yield flush_text_delta(remaining_content) + for text_event in build_text_delta_events(remaining_content): + yield text_event buffer = "" # Clear buffer after extracting state else: - yield flush_text_delta(delta_content) + for text_event in build_text_delta_events(delta_content): + yield text_event buffer = "" except Exception as e: @@ -307,7 +337,11 @@ async def chat( # If stream ends and no state code found (unlikely if format is strict), # we might want to send what we have if not state_code_found and buffer: - yield flush_text_delta(buffer) + for text_event in build_text_delta_events(buffer): + yield text_event + + for text_event in flush_text_chunker_events(): + yield text_event text_delta_end_ms = ( f"{(last_text_delta_at - stream_started_at) * 1000:.1f}" @@ -320,6 +354,7 @@ async def chat( f"duration_ms={(time.perf_counter() - stream_started_at) * 1000:.1f} " f"text_delta_end_ms={text_delta_end_ms} " f"text_delta_count={text_delta_count} " + f"useTextChunk={use_text_chunk} " f"stage_code_found={state_code_found} formUpdate_sent={module_form_sent} " f"output={''.join(output_chunks)!r} " f"formUpdate={form_update_payload!r}" diff --git a/src/schemas/models.py b/src/schemas/models.py index f178959..a7425fa 100644 --- a/src/schemas/models.py +++ b/src/schemas/models.py @@ -6,6 +6,7 @@ class ProcessRequest_chat(BaseModel): timeStamp: str = Field(..., max_length=32) text: str = Field(...) needFormUpdate: bool = False + useTextChunk: bool = False class ProcessResponse_chat(BaseModel): sessionId: str = Field(..., max_length=64) diff --git a/src/voice/text_chunker.py b/src/voice/text_chunker.py new file mode 100644 index 0000000..27410a4 --- /dev/null +++ b/src/voice/text_chunker.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +SENTENCE_ENDING_PUNCTUATION = frozenset(".!?;。!?;") +SOFT_BREAK_PUNCTUATION = frozenset(",,、::") +CLOSING_PUNCTUATION = frozenset("\"'”’)]})】》」』") + + +@dataclass(frozen=True) +class SentenceTextChunkerConfig: + """Configuration for streaming text chunks sent to TTS.""" + + min_chars: int = 1 + max_chars: int = 80 + use_soft_breaks: bool = True + + +class SentenceTextChunker: + """Lightweight Pipecat-style sentence chunker for streaming TTS text. + + The chunker waits for one non-whitespace lookahead character after sentence + punctuation before emitting a sentence. This avoids splitting too early when + a punctuation mark arrives at the end of a stream token. + """ + + def __init__(self, config: SentenceTextChunkerConfig | None = None) -> None: + self._config = config or SentenceTextChunkerConfig() + self._buffer = "" + self._needs_lookahead = False + self._last_soft_break = -1 + + @property + def text(self) -> str: + return self._buffer.strip(" ") + + def feed(self, text: str) -> list[str]: + """Append streaming text and return chunks ready for TTS.""" + chunks: list[str] = [] + if not text: + return chunks + + for char in text: + self._buffer += char + index = len(self._buffer) - 1 + + if char in SOFT_BREAK_PUNCTUATION: + self._last_soft_break = index + 1 + + if self._needs_lookahead and char.strip(): + self._needs_lookahead = False + chunk = self._pop_sentence_chunk() + if chunk: + chunks.append(chunk) + continue + + if char in SENTENCE_ENDING_PUNCTUATION and not self._is_decimal_point(index): + self._needs_lookahead = True + + chunk = self._pop_soft_chunk_if_needed() + if chunk: + chunks.append(chunk) + + return chunks + + def flush(self) -> str | None: + """Return any remaining buffered text at end of stream.""" + if not self._buffer: + return None + chunk = self._buffer.strip(" ") + self.reset() + return chunk or None + + def reset(self) -> None: + self._buffer = "" + self._needs_lookahead = False + self._last_soft_break = -1 + + def _pop_sentence_chunk(self) -> str | None: + end = self._sentence_end_index() + if end is None: + return None + chunk = self._buffer[:end].strip(" ") + if len(chunk) < self._config.min_chars: + return None + self._buffer = self._buffer[end:] + self._last_soft_break = self._find_last_soft_break() + return chunk + + def _sentence_end_index(self) -> int | None: + index = 0 + while index < len(self._buffer): + char = self._buffer[index] + if char in SENTENCE_ENDING_PUNCTUATION and not self._is_decimal_point(index): + end = index + 1 + while end < len(self._buffer) and self._buffer[end] in CLOSING_PUNCTUATION: + end += 1 + return end + index += 1 + return None + + def _pop_soft_chunk_if_needed(self) -> str | None: + if ( + not self._config.use_soft_breaks + or self._config.max_chars <= 0 + or len(self._buffer) < self._config.max_chars + or self._last_soft_break <= 0 + ): + return None + + chunk = self._buffer[: self._last_soft_break].strip(" ") + if len(chunk) < self._config.min_chars: + return None + + self._buffer = self._buffer[self._last_soft_break :] + self._last_soft_break = self._find_last_soft_break() + self._needs_lookahead = False + return chunk + + def _find_last_soft_break(self) -> int: + for index in range(len(self._buffer) - 1, -1, -1): + if self._buffer[index] in SOFT_BREAK_PUNCTUATION: + return index + 1 + return -1 + + def _is_decimal_point(self, index: int) -> bool: + if self._buffer[index] != ".": + return False + return ( + index > 0 + and index + 1 < len(self._buffer) + and self._buffer[index - 1].isdigit() + and self._buffer[index + 1].isdigit() + ) diff --git a/test/test_text_chunker.py b/test/test_text_chunker.py new file mode 100644 index 0000000..cb2f2ad --- /dev/null +++ b/test/test_text_chunker.py @@ -0,0 +1,57 @@ +from src.voice.text_chunker import SentenceTextChunker, SentenceTextChunkerConfig + + +def test_chinese_sentence_chunks_wait_for_lookahead(): + chunker = SentenceTextChunker() + chunks = [] + + for token in ["你好", "世界", "。", "下一", "句话", "。"]: + chunks.extend(chunker.feed(token)) + + assert chunks == ["你好世界。"] + assert chunker.flush() == "下一句话。" + + +def test_flush_returns_pending_text(): + chunker = SentenceTextChunker() + + assert chunker.feed("还没有句号") == [] + assert chunker.flush() == "还没有句号" + assert chunker.flush() is None + + +def test_decimal_point_does_not_split_sentence(): + chunker = SentenceTextChunker() + + chunks = chunker.feed("价格是29.95元。下一句") + + assert chunks == ["价格是29.95元。"] + assert chunker.flush() == "下一句" + + +def test_soft_break_after_max_chars(): + chunker = SentenceTextChunker( + SentenceTextChunkerConfig( + min_chars=1, + max_chars=12, + use_soft_breaks=True, + ) + ) + + chunks = chunker.feed("这是一段比较长的话,需要先切一下继续播放") + + assert chunks == ["这是一段比较长的话,"] + assert chunker.flush() == "需要先切一下继续播放" + + +def test_can_disable_soft_breaks(): + chunker = SentenceTextChunker( + SentenceTextChunkerConfig( + min_chars=1, + max_chars=12, + use_soft_breaks=False, + ) + ) + + assert chunker.feed("这是一段比较长的话,需要先切一下继续播放") == [] + assert chunker.flush() == "这是一段比较长的话,需要先切一下继续播放"