Add text chunking functionality to chat endpoint
- Introduced SentenceTextChunker and SentenceTextChunkerConfig for improved text processing in chat responses. - Updated chat endpoint to conditionally use text chunking based on the new 'useTextChunk' parameter from the request. - Enhanced logging to include 'useTextChunk' status and adjusted text delta handling to support chunked responses. - Modified ProcessRequest_chat model to include 'useTextChunk' field for request handling. - Added unit tests for SentenceTextChunker to ensure correct chunking behavior and edge case handling.
This commit is contained in:
@@ -7,6 +7,7 @@ from fastgpt_client.exceptions import (
|
||||
)
|
||||
from ..core.fastgpt_client import get_fastgpt_client
|
||||
from ..core.config import Config
|
||||
from ..voice.text_chunker import SentenceTextChunker, SentenceTextChunkerConfig
|
||||
from loguru import logger
|
||||
import json
|
||||
import re
|
||||
@@ -173,12 +174,14 @@ async def chat(
|
||||
"""Handle chat completion request."""
|
||||
json_data = request.model_dump()
|
||||
need_form_update = json_data.get('needFormUpdate', False)
|
||||
use_text_chunk = json_data.get('useTextChunk', False)
|
||||
chat_variables = {'needFormUpdate': need_form_update}
|
||||
request_started_at = time.perf_counter()
|
||||
logger.info(
|
||||
"Chat request received "
|
||||
f"sessionId={json_data['sessionId']} stream={stream} "
|
||||
f"needFormUpdate={need_form_update} text_len={len(json_data.get('text', ''))} "
|
||||
f"needFormUpdate={need_form_update} useTextChunk={use_text_chunk} "
|
||||
f"text_len={len(json_data.get('text', ''))} "
|
||||
f"input={json_data.get('text', '')!r}"
|
||||
)
|
||||
|
||||
@@ -191,6 +194,17 @@ async def chat(
|
||||
text_delta_count = 0
|
||||
output_chunks = []
|
||||
form_update_payload = {}
|
||||
text_chunker = (
|
||||
SentenceTextChunker(
|
||||
SentenceTextChunkerConfig(
|
||||
min_chars=1,
|
||||
max_chars=0,
|
||||
use_soft_breaks=False,
|
||||
)
|
||||
)
|
||||
if use_text_chunk
|
||||
else None
|
||||
)
|
||||
try:
|
||||
# Use SDK's create_chat_completion with stream=True
|
||||
response = await client.create_chat_completion(
|
||||
@@ -227,6 +241,20 @@ async def chat(
|
||||
|
||||
def flush_form_update(form_update):
|
||||
return create_sse_event("formUpdate", form_update)
|
||||
|
||||
def build_text_delta_events(text: str):
|
||||
if not text:
|
||||
return []
|
||||
chunks = text_chunker.feed(text) if text_chunker else [text]
|
||||
return [flush_text_delta(chunk) for chunk in chunks if chunk]
|
||||
|
||||
def flush_text_chunker_events():
|
||||
if not text_chunker:
|
||||
return []
|
||||
chunk = text_chunker.flush()
|
||||
if not chunk:
|
||||
return []
|
||||
return [flush_text_delta(chunk)]
|
||||
|
||||
async for event in aiter_stream_events(response):
|
||||
try:
|
||||
@@ -294,10 +322,12 @@ async def chat(
|
||||
# Send remaining content as text_delta
|
||||
remaining_content = buffer[match.end():]
|
||||
if remaining_content:
|
||||
yield flush_text_delta(remaining_content)
|
||||
for text_event in build_text_delta_events(remaining_content):
|
||||
yield text_event
|
||||
buffer = "" # Clear buffer after extracting state
|
||||
else:
|
||||
yield flush_text_delta(delta_content)
|
||||
for text_event in build_text_delta_events(delta_content):
|
||||
yield text_event
|
||||
buffer = ""
|
||||
|
||||
except Exception as e:
|
||||
@@ -307,7 +337,11 @@ async def chat(
|
||||
# If stream ends and no state code found (unlikely if format is strict),
|
||||
# we might want to send what we have
|
||||
if not state_code_found and buffer:
|
||||
yield flush_text_delta(buffer)
|
||||
for text_event in build_text_delta_events(buffer):
|
||||
yield text_event
|
||||
|
||||
for text_event in flush_text_chunker_events():
|
||||
yield text_event
|
||||
|
||||
text_delta_end_ms = (
|
||||
f"{(last_text_delta_at - stream_started_at) * 1000:.1f}"
|
||||
@@ -320,6 +354,7 @@ async def chat(
|
||||
f"duration_ms={(time.perf_counter() - stream_started_at) * 1000:.1f} "
|
||||
f"text_delta_end_ms={text_delta_end_ms} "
|
||||
f"text_delta_count={text_delta_count} "
|
||||
f"useTextChunk={use_text_chunk} "
|
||||
f"stage_code_found={state_code_found} formUpdate_sent={module_form_sent} "
|
||||
f"output={''.join(output_chunks)!r} "
|
||||
f"formUpdate={form_update_payload!r}"
|
||||
|
||||
@@ -6,6 +6,7 @@ class ProcessRequest_chat(BaseModel):
|
||||
timeStamp: str = Field(..., max_length=32)
|
||||
text: str = Field(...)
|
||||
needFormUpdate: bool = False
|
||||
useTextChunk: bool = False
|
||||
|
||||
class ProcessResponse_chat(BaseModel):
|
||||
sessionId: str = Field(..., max_length=64)
|
||||
|
||||
135
src/voice/text_chunker.py
Normal file
135
src/voice/text_chunker.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
SENTENCE_ENDING_PUNCTUATION = frozenset(".!?;。!?;")
|
||||
SOFT_BREAK_PUNCTUATION = frozenset(",,、::")
|
||||
CLOSING_PUNCTUATION = frozenset("\"'”’)]})】》」』")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SentenceTextChunkerConfig:
|
||||
"""Configuration for streaming text chunks sent to TTS."""
|
||||
|
||||
min_chars: int = 1
|
||||
max_chars: int = 80
|
||||
use_soft_breaks: bool = True
|
||||
|
||||
|
||||
class SentenceTextChunker:
|
||||
"""Lightweight Pipecat-style sentence chunker for streaming TTS text.
|
||||
|
||||
The chunker waits for one non-whitespace lookahead character after sentence
|
||||
punctuation before emitting a sentence. This avoids splitting too early when
|
||||
a punctuation mark arrives at the end of a stream token.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SentenceTextChunkerConfig | None = None) -> None:
|
||||
self._config = config or SentenceTextChunkerConfig()
|
||||
self._buffer = ""
|
||||
self._needs_lookahead = False
|
||||
self._last_soft_break = -1
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self._buffer.strip(" ")
|
||||
|
||||
def feed(self, text: str) -> list[str]:
|
||||
"""Append streaming text and return chunks ready for TTS."""
|
||||
chunks: list[str] = []
|
||||
if not text:
|
||||
return chunks
|
||||
|
||||
for char in text:
|
||||
self._buffer += char
|
||||
index = len(self._buffer) - 1
|
||||
|
||||
if char in SOFT_BREAK_PUNCTUATION:
|
||||
self._last_soft_break = index + 1
|
||||
|
||||
if self._needs_lookahead and char.strip():
|
||||
self._needs_lookahead = False
|
||||
chunk = self._pop_sentence_chunk()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
continue
|
||||
|
||||
if char in SENTENCE_ENDING_PUNCTUATION and not self._is_decimal_point(index):
|
||||
self._needs_lookahead = True
|
||||
|
||||
chunk = self._pop_soft_chunk_if_needed()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
def flush(self) -> str | None:
|
||||
"""Return any remaining buffered text at end of stream."""
|
||||
if not self._buffer:
|
||||
return None
|
||||
chunk = self._buffer.strip(" ")
|
||||
self.reset()
|
||||
return chunk or None
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer = ""
|
||||
self._needs_lookahead = False
|
||||
self._last_soft_break = -1
|
||||
|
||||
def _pop_sentence_chunk(self) -> str | None:
|
||||
end = self._sentence_end_index()
|
||||
if end is None:
|
||||
return None
|
||||
chunk = self._buffer[:end].strip(" ")
|
||||
if len(chunk) < self._config.min_chars:
|
||||
return None
|
||||
self._buffer = self._buffer[end:]
|
||||
self._last_soft_break = self._find_last_soft_break()
|
||||
return chunk
|
||||
|
||||
def _sentence_end_index(self) -> int | None:
|
||||
index = 0
|
||||
while index < len(self._buffer):
|
||||
char = self._buffer[index]
|
||||
if char in SENTENCE_ENDING_PUNCTUATION and not self._is_decimal_point(index):
|
||||
end = index + 1
|
||||
while end < len(self._buffer) and self._buffer[end] in CLOSING_PUNCTUATION:
|
||||
end += 1
|
||||
return end
|
||||
index += 1
|
||||
return None
|
||||
|
||||
def _pop_soft_chunk_if_needed(self) -> str | None:
|
||||
if (
|
||||
not self._config.use_soft_breaks
|
||||
or self._config.max_chars <= 0
|
||||
or len(self._buffer) < self._config.max_chars
|
||||
or self._last_soft_break <= 0
|
||||
):
|
||||
return None
|
||||
|
||||
chunk = self._buffer[: self._last_soft_break].strip(" ")
|
||||
if len(chunk) < self._config.min_chars:
|
||||
return None
|
||||
|
||||
self._buffer = self._buffer[self._last_soft_break :]
|
||||
self._last_soft_break = self._find_last_soft_break()
|
||||
self._needs_lookahead = False
|
||||
return chunk
|
||||
|
||||
def _find_last_soft_break(self) -> int:
|
||||
for index in range(len(self._buffer) - 1, -1, -1):
|
||||
if self._buffer[index] in SOFT_BREAK_PUNCTUATION:
|
||||
return index + 1
|
||||
return -1
|
||||
|
||||
def _is_decimal_point(self, index: int) -> bool:
|
||||
if self._buffer[index] != ".":
|
||||
return False
|
||||
return (
|
||||
index > 0
|
||||
and index + 1 < len(self._buffer)
|
||||
and self._buffer[index - 1].isdigit()
|
||||
and self._buffer[index + 1].isdigit()
|
||||
)
|
||||
57
test/test_text_chunker.py
Normal file
57
test/test_text_chunker.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from src.voice.text_chunker import SentenceTextChunker, SentenceTextChunkerConfig
|
||||
|
||||
|
||||
def test_chinese_sentence_chunks_wait_for_lookahead():
|
||||
chunker = SentenceTextChunker()
|
||||
chunks = []
|
||||
|
||||
for token in ["你好", "世界", "。", "下一", "句话", "。"]:
|
||||
chunks.extend(chunker.feed(token))
|
||||
|
||||
assert chunks == ["你好世界。"]
|
||||
assert chunker.flush() == "下一句话。"
|
||||
|
||||
|
||||
def test_flush_returns_pending_text():
|
||||
chunker = SentenceTextChunker()
|
||||
|
||||
assert chunker.feed("还没有句号") == []
|
||||
assert chunker.flush() == "还没有句号"
|
||||
assert chunker.flush() is None
|
||||
|
||||
|
||||
def test_decimal_point_does_not_split_sentence():
|
||||
chunker = SentenceTextChunker()
|
||||
|
||||
chunks = chunker.feed("价格是29.95元。下一句")
|
||||
|
||||
assert chunks == ["价格是29.95元。"]
|
||||
assert chunker.flush() == "下一句"
|
||||
|
||||
|
||||
def test_soft_break_after_max_chars():
|
||||
chunker = SentenceTextChunker(
|
||||
SentenceTextChunkerConfig(
|
||||
min_chars=1,
|
||||
max_chars=12,
|
||||
use_soft_breaks=True,
|
||||
)
|
||||
)
|
||||
|
||||
chunks = chunker.feed("这是一段比较长的话,需要先切一下继续播放")
|
||||
|
||||
assert chunks == ["这是一段比较长的话,"]
|
||||
assert chunker.flush() == "需要先切一下继续播放"
|
||||
|
||||
|
||||
def test_can_disable_soft_breaks():
|
||||
chunker = SentenceTextChunker(
|
||||
SentenceTextChunkerConfig(
|
||||
min_chars=1,
|
||||
max_chars=12,
|
||||
use_soft_breaks=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert chunker.feed("这是一段比较长的话,需要先切一下继续播放") == []
|
||||
assert chunker.flush() == "这是一段比较长的话,需要先切一下继续播放"
|
||||
Reference in New Issue
Block a user