1111 lines
48 KiB
Python
1111 lines
48 KiB
Python
import argparse
|
||
import asyncio
|
||
import base64
|
||
import json
|
||
import logging
|
||
import os
|
||
import random
|
||
import sys
|
||
import re
|
||
from dataclasses import asdict, dataclass
|
||
from functools import partial
|
||
|
||
import aiohttp
|
||
import httpx
|
||
from dotenv import load_dotenv
|
||
|
||
from livekit import api, rtc
|
||
from livekit.agents import (
|
||
Agent,
|
||
AgentSession,
|
||
AudioConfig,
|
||
BackgroundAudioPlayer,
|
||
BuiltinAudioClip,
|
||
JobContext,
|
||
JobProcess,
|
||
UserStateChangedEvent,
|
||
MetricsCollectedEvent,
|
||
RoomInputOptions,
|
||
RoomOutputOptions,
|
||
RunContext,
|
||
WorkerOptions,
|
||
cli,
|
||
get_job_context,
|
||
metrics,
|
||
RoomIO
|
||
)
|
||
from livekit.agents.llm import ImageContent, ToolError, function_tool
|
||
from typing import Any, List, Optional
|
||
from livekit.agents.voice.avatar import DataStreamAudioOutput
|
||
from livekit.agents.voice.io import PlaybackFinishedEvent
|
||
from livekit.agents.voice.room_io import ATTRIBUTE_PUBLISH_ON_BEHALF
|
||
from livekit.plugins import silero
|
||
from livekit.plugins import openai, azure, minimax, aliyun, volcengine
|
||
# from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||
|
||
from datetime import datetime
|
||
|
||
logger = logging.getLogger("basic-agent")
|
||
|
||
# Load environment variables from .env file in the agents directory
|
||
env_path = os.path.join(os.path.dirname(__file__), ".env")
|
||
load_dotenv(env_path)
|
||
# Also try loading from current directory as fallback
|
||
load_dotenv()
|
||
|
||
AVATAR_IDENTITY = "avatar_worker"
|
||
|
||
DEFAULT_INSTRUCTIONS = """# 角色
|
||
你是无锡交警智能机器人,负责收集交通事故信息。
|
||
|
||
你像真人一样和用户对话,而不是机器人。你收到的文字是经过ASR识别的语音输入。
|
||
|
||
# 当前日期和时间
|
||
日期:{datetime}
|
||
星期:{weekday}
|
||
|
||
# 能力
|
||
- 你具有调用工具操作前端界面系统的能力
|
||
- ask_image_capture工具被调用后会在系统播放拍摄的目标和需求,所以你每次在调用它之前不需要重复引导用户拍摄什么
|
||
|
||
# 任务
|
||
你的职责是全流程引导用户完成:事故信息采集 -> 现场证据拍照 -> 驾驶员信息核实。
|
||
|
||
## 事故信息采集阶段
|
||
- 在事故信息采集阶段:询问是否有人受伤,请求用户简单描述事故情况,询问事故发生时间并通过复述标准化时间(xx年xx月xx日xx时xx分)向用户确认,询问事故车辆数量,询问事故发生的原因(例如追尾、刮擦、碰撞等)。采集完成后进入现场证据拍照阶段
|
||
- 如果用户回答已包含需要问题的答案,改为与用户确认答案是否正确
|
||
- 采集完成之后进入现场证据拍照阶段
|
||
|
||
## 现场证据拍照阶段
|
||
- 在现场证据拍照阶段:使用askImageCapture工具引导用户依次拍摄照片:1. 第一辆车的车牌;2. 第一辆车的碰撞位置;3. 第一辆车的驾驶员正脸;
|
||
- 如果车辆涉及一辆车,则询问一位驾驶员,如果涉及两辆车,则要求拍摄:4. 第二辆车的车牌;5. 第二辆车的碰撞位置;6. 第二辆车的驾驶员正脸;
|
||
- 拍摄完成之后和用户确认识别的车牌号是否正确
|
||
- 完成之后进入驾驶员信息核实阶段
|
||
|
||
## 驾驶员信息核实阶段
|
||
- 你只处理事故车辆有一辆或者两辆的情况,超过两辆的情况需要转人工处理
|
||
- 对于一辆车辆的情况,你首先询问司机的姓名,之后根据车牌号查询驾驶员手机号,如果查询到则用手机号后四位向用户确认,如果未查询到或者用户告知手机号后四位错误,则询问用户手机号。接着根据车牌号查询驾驶员身份证号,如果查询到则用身份证号后四位向用户确认,如果未查询到或者用户告知身份证号后四位错误,则询问用户身份证号
|
||
- 对于两辆车辆的情况,你在手机收集完成第一位驾驶员的信息后主动要求与第二位驾驶员通话,你通过主动询问确认电话转接之后再进行后续的通话
|
||
- 收集第二位驾驶员的过程与第一位驾驶员相同
|
||
- 完成之后进入后续办理提示阶段
|
||
|
||
## 后续办理提示阶段
|
||
- 在后续办理提示阶段:使用ask_important_question显示已经提取的交通事故信息,提示用户点击转人工继续处理,用户点击之后调用enter_hand_off_to_human_mode工具转人工。
|
||
|
||
# 要求
|
||
- 在通话开始
|
||
- 你会在ask_image_capture的prompt参数中告诉用户拍摄的目标,所以避免在对话中重复描述需要用户拍摄什么
|
||
- 使用get_mobile_by_plate和get_id_card_by_plate的时候不要告诉用户正在查询,执行工具查看结果即可
|
||
|
||
# 回复风格
|
||
- 使用第一人称,语言简洁
|
||
- 一次询问一个问题
|
||
- 不要在你的回复中使用 emojis, asterisks, markdown, 或其他特殊字符
|
||
- 不同阶段直接的过渡语句自然
|
||
- 你已经说过下面的开场白所以不需要重复说:“您好,这里是无锡交警,我将为您远程处理交通事故。请将人员撤离至路侧安全区域,开启危险报警双闪灯、放置三角警告牌、做好安全防护,谨防二次事故伤害。若您已经准备好了,请点击继续办理,如需人工服务,请说转人工。”
|
||
"""
|
||
|
||
DEFAULT_TALKING_MODE = 'push_to_talk'
|
||
|
||
# ## 黄金对话路径示例 (GOLDEN_CONVERSATION_PATH)
|
||
|
||
# ```
|
||
# # 事故初审阶段
|
||
# AI: 您好,请问现场有人受伤吗?
|
||
# 用户: 没有人员受伤。
|
||
# AI: 请您简单描述一下事故的经过。
|
||
# 用户: 我的车等红灯时被后面的车追尾了。
|
||
# AI: 请问事故大概发生在什么时间?
|
||
# 用户: 今天早上八点左右发生的。
|
||
# AI: 一共涉及几辆车?
|
||
# 用户: 两辆。
|
||
# AI: 事故主要原因是追尾,对吗?
|
||
# 用户: 对,是被追尾。
|
||
# AI: 好的,接下来我们将进行现场证据采集。
|
||
|
||
# # 现场证据拍照阶段
|
||
# AI: 现在需要采集事故现场照片,请先拍摄事故现场的整体照片并上传。
|
||
# (调用 ask_image_capture,prompt="请拍摄事故现场的整体照片", extract_target="事故现场", requirements="清晰、完整,无模糊和遮挡")
|
||
# 工具返回: {"success": true, "analysis": {"target_found": false, "quality_ok": false, "quality_notes": ["画面模糊"]}, "raw": "..."}
|
||
# 用户上传模糊照片
|
||
# AI:您上传的照片不够清晰或存在遮挡,请确保画面清晰、完整地拍摄事故现场,请重新拍摄并上传。
|
||
# (调用 ask_image_capture,prompt="请重新清晰完整地拍摄事故现场照片", extract_target="事故现场", requirements="清晰、完整,无模糊和遮挡")
|
||
# 工具返回: {"success": true, "analysis": {"target_found": true, "quality_ok": true, "summary": "清晰的事故现场概览"}, "raw": "..."}
|
||
# 用户上传清晰照片
|
||
# AI:很好,事故现场照片已收到。接下来请拍摄驾驶员正脸照片。
|
||
# (调用 ask_image_capture,prompt="请拍摄驾驶员正脸照片", extract_target="驾驶员正脸", requirements="面部清晰、正对镜头")
|
||
# 工具返回: {"success": true, "analysis": {"target_found": false, "quality_ok": false, "quality_notes": ["未正对镜头"]}, "raw": "..."}
|
||
# 用户上传侧脸照片
|
||
# AI:检测到照片未正对镜头或面部不清楚,请正对镜头并确保面部清晰,再次拍摄并上传正脸照片。
|
||
# (调用 ask_image_capture,prompt="请正对镜头并拍摄清晰的驾驶员正脸照片", extract_target="驾驶员正脸", requirements="面部清晰、正对镜头")
|
||
# 工具返回: {"success": true, "analysis": {"target_found": true, "quality_ok": true, "summary": "驾驶员正脸清晰"}, "raw": "..."}
|
||
# 用户上传正脸照片
|
||
# AI:驾驶员正脸照片已收到。最后,请拍摄并上传车辆号牌的清晰照片。
|
||
# (调用 ask_image_capture,prompt="请拍摄车辆号牌照片", extract_target="车辆号牌", requirements="完整、清晰、无遮挡")
|
||
# 工具返回: {"success": true, "analysis": {"target_found": true, "quality_ok": true, "summary": "车牌完整可读", "target_details": "苏A12345"}, "raw": "..."}
|
||
# 用户上传合格车牌照片
|
||
# AI:我识别到车牌号是“苏A12345”,请确认是否正确,如需修改请告知。
|
||
|
||
# # 驾驶员信息核实阶段
|
||
# AI: 请提供您的姓名,用于信息核实。
|
||
# 用户: 张三
|
||
# AI:(调用 get_mobile_by_plate,plate_number="苏A12345")
|
||
# 工具返回: {"success": true, "plate": "苏A12345", "mobile": "13800001234"}
|
||
# AI:系统查询到登记手机号尾号1234,请确认是否正确;如不正确,请告知真实手机号。
|
||
# 用户: 正确
|
||
# AI:(调用 get_id_card_by_plate,plate_number="苏A12345")
|
||
# 工具返回: {"success": true, "plate": "苏A12345", "id_card": "320101198001011234"}
|
||
# AI:系统查询到登记身份证尾号1234,请确认是否正确;如不正确,请告知完整身份证号。
|
||
# 用户: 正确
|
||
# AI: 信息采集已完成,请点击确认进入后续办理。(调用 askImportantQuestion,message="请确认已采集信息:\n姓名:张三\n手机号:13800001234(尾号1234)\n身份证号:320101198001011234(尾号1234)\n车牌号:苏A12345", options=["确认无误", "需要修改"])
|
||
# 工具返回: {"selection": "确认无误"}
|
||
# AI: 已确认,后续流程请点击转人工继续处理。(调用 enter_hand_off_to_human_mode)
|
||
# ```
|
||
|
||
|
||
@dataclass
|
||
class AvatarConnectionInfo:
|
||
room_name: str
|
||
url: str
|
||
"""LiveKit server URL"""
|
||
token: str
|
||
"""Token for avatar worker to join"""
|
||
|
||
|
||
class MyAgent(Agent):
|
||
def __init__(self, instructions: str, vision_backend: str = "aliyun") -> None:
|
||
self.vision_backend = vision_backend
|
||
self._tasks = []
|
||
self._captured_images = []
|
||
self._image_event = asyncio.Event()
|
||
super().__init__(
|
||
instructions=instructions,
|
||
)
|
||
|
||
async def on_enter(self):
|
||
self.session.generate_reply(
|
||
instructions="调用ask_important_question,message=\"您好,这里是无锡交警,我将为您远程处理交通事故。请将人员撤离至路侧安全区域,开启危险报警双闪灯、放置三角警告牌、做好安全防护,谨防二次事故伤害。若您已经准备好了,请点击继续办理。\",options=[\"继续办理\"]", allow_interruptions=False)
|
||
|
||
# Register byte stream handler for image uploads from frontend
|
||
def _image_received_handler(reader, participant_identity):
|
||
task = asyncio.create_task(
|
||
self._image_received(reader, participant_identity)
|
||
)
|
||
self._tasks.append(task)
|
||
task.add_done_callback(lambda t: self._tasks.remove(t))
|
||
|
||
# Add the handler when the agent joins
|
||
get_job_context().room.register_byte_stream_handler("image", _image_received_handler)
|
||
|
||
async def _send_chat_message(self, message: str):
|
||
"""Helper to send a chat message to the room."""
|
||
try:
|
||
room = get_job_context().room
|
||
import time
|
||
# Construct payload matching what useChat likely expects or just a simple JSON
|
||
# Standard LiveKit chat format often involves a simple JSON wrapper
|
||
payload = json.dumps({
|
||
"message": message,
|
||
"timestamp": int(time.time() * 1000)
|
||
}).encode('utf-8')
|
||
|
||
await room.local_participant.publish_data(
|
||
payload=payload,
|
||
topic="lk-chat-topic"
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Failed to send chat message: {e}")
|
||
|
||
async def _image_received(self, reader, participant_identity):
|
||
"""Handle image uploads from the frontend via byte stream."""
|
||
logger.info(f"Received image upload from participant: {participant_identity}")
|
||
image_bytes = bytes()
|
||
async for chunk in reader:
|
||
image_bytes += chunk
|
||
|
||
logger.info(f"Image received: {len(image_bytes)} bytes from {participant_identity}")
|
||
|
||
# Store the received image bytes in memory
|
||
self._captured_images.append(image_bytes)
|
||
self._image_event.set()
|
||
|
||
# Notify user that image was received
|
||
# chat_ctx = self.chat_ctx.copy()
|
||
# chat_ctx.add_message(
|
||
# role="user",
|
||
# content=["I have uploaded an image. You can analyze it when I ask you to."]
|
||
# )
|
||
# await self.update_chat_ctx(chat_ctx)
|
||
|
||
# Trigger a reply acknowledging receipt
|
||
# self.session.generate_reply()
|
||
|
||
def _normalize_vision_response(self, raw_content: Any) -> str:
|
||
"""Convert various content formats to a clean JSON string (no code fences)."""
|
||
# Handle multimodal content list responses
|
||
if isinstance(raw_content, list):
|
||
text_parts: List[str] = []
|
||
for part in raw_content:
|
||
if isinstance(part, dict):
|
||
if part.get("type") == "text" and "text" in part:
|
||
text_parts.append(str(part["text"]))
|
||
else:
|
||
text_parts.append(str(part))
|
||
raw_content = "".join(text_parts)
|
||
|
||
content = str(raw_content).strip()
|
||
if content.startswith("```"):
|
||
# Strip fences like ```json ... ```
|
||
content = content.lstrip("`")
|
||
if content.lower().startswith("json"):
|
||
content = content[4:].lstrip()
|
||
if "```" in content:
|
||
content = content.split("```", 1)[0].strip()
|
||
return content
|
||
|
||
async def describe_image_with_api(
|
||
self,
|
||
base64_image: str,
|
||
backend: str = "siliconflow",
|
||
prompt: str = "",
|
||
extract_target: str = "",
|
||
requirements: str = "",
|
||
) -> str:
|
||
"""
|
||
Send base64 image to Vision API (SiliconFlow or Aliyun) and request a
|
||
JSON-formatted analysis aligned with ask_image_capture signature.
|
||
"""
|
||
if backend == "aliyun":
|
||
return await self._describe_image_aliyun(
|
||
base64_image,
|
||
prompt=prompt,
|
||
extract_target=extract_target,
|
||
requirements=requirements,
|
||
)
|
||
else:
|
||
return await self._describe_image_siliconflow(
|
||
base64_image,
|
||
prompt=prompt,
|
||
extract_target=extract_target,
|
||
requirements=requirements,
|
||
)
|
||
|
||
async def _describe_image_aliyun(
|
||
self,
|
||
base64_image: str,
|
||
prompt: str,
|
||
extract_target: str,
|
||
requirements: str,
|
||
) -> str:
|
||
logger.info("Sending image to Aliyun Vision API (Qwen-VL-Max)...")
|
||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||
if not api_key:
|
||
logger.error("DASHSCOPE_API_KEY not found")
|
||
return "Error: DASHSCOPE_API_KEY not configured."
|
||
|
||
url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
|
||
|
||
analysis_prompt = (
|
||
"请根据以下任务要求对图片进行结构化分析,并且只返回紧凑的JSON:"
|
||
f'{{"prompt": "{prompt}", '
|
||
f'"extract_target": "{extract_target}", '
|
||
f'"requirements": "{requirements}", '
|
||
'"summary": "<用不超过50字中文概述图片>", '
|
||
'"target_found": <true|false>, '
|
||
'"target_details": "<若存在目标,描述关键特征;否则写未发现>", '
|
||
'"quality_ok": <true|false>, '
|
||
'"quality_notes": ["<不满足要求的原因,若满足可为空数组>"]}'
|
||
"。不要添加额外说明、前缀或代码块。"
|
||
)
|
||
|
||
payload = {
|
||
"model": "qwen-vl-plus",
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
},
|
||
{
|
||
"type": "text",
|
||
"text": analysis_prompt
|
||
}
|
||
]
|
||
}
|
||
]
|
||
}
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, json=payload, headers=headers) as response:
|
||
if response.status == 200:
|
||
data = await response.json()
|
||
raw_content = data['choices'][0]['message']['content']
|
||
normalized = self._normalize_vision_response(raw_content)
|
||
try:
|
||
parsed = json.loads(normalized)
|
||
description = json.dumps(parsed, ensure_ascii=False)
|
||
except Exception:
|
||
description = normalized
|
||
logger.info(f"Got image description from Aliyun: {description}")
|
||
return description
|
||
else:
|
||
error_text = await response.text()
|
||
logger.error(f"Aliyun API error: {response.status} - {error_text}")
|
||
return "Failed to analyze image due to Aliyun API error."
|
||
except Exception as e:
|
||
logger.error(f"Exception calling Aliyun Vision API: {e}")
|
||
return f"Error analyzing image: {str(e)}"
|
||
|
||
async def _describe_image_siliconflow(
|
||
self,
|
||
base64_image: str,
|
||
prompt: str,
|
||
extract_target: str,
|
||
requirements: str,
|
||
) -> str:
|
||
"""
|
||
Send base64 image to SiliconFlow Vision API.
|
||
"""
|
||
logger.info("Sending image to SiliconFlow Vision API...")
|
||
|
||
# NOTE: You need to set your SiliconFlow API key here or in your .env file
|
||
api_key = os.getenv("SILICONFLOW_API_KEY", "your-api-key")
|
||
|
||
url = "https://api.siliconflow.cn/v1/chat/completions"
|
||
|
||
analysis_prompt = (
|
||
"请根据以下任务要求输出JSON结果,不要附加任何解释或代码块:"
|
||
f'{{"prompt": "{prompt}", '
|
||
f'"extract_target": "{extract_target}", '
|
||
f'"requirements": "{requirements}", '
|
||
'"summary": "<50字内中文概述图片>", '
|
||
'"target_found": <true|false>, '
|
||
'"target_details": "<若找到目标描述关键特征,否则写未发现>", '
|
||
'"quality_ok": <true|false>, '
|
||
'"quality_notes": ["<若不满足拍摄要求,列出问题;满足则为空数组>"]}'
|
||
)
|
||
|
||
payload = {
|
||
"model": "Qwen/Qwen3-VL-8B-Instruct",
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
},
|
||
{
|
||
"type": "text",
|
||
"text": analysis_prompt
|
||
}
|
||
]
|
||
}
|
||
],
|
||
"max_tokens": 512
|
||
}
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, json=payload, headers=headers) as response:
|
||
if response.status == 200:
|
||
data = await response.json()
|
||
raw_content = data['choices'][0]['message']['content']
|
||
normalized = self._normalize_vision_response(raw_content)
|
||
try:
|
||
parsed = json.loads(normalized)
|
||
description = json.dumps(parsed, ensure_ascii=False)
|
||
except Exception:
|
||
description = normalized
|
||
logger.info(f"Got image description: {description}")
|
||
return description
|
||
else:
|
||
error_text = await response.text()
|
||
logger.error(f"SiliconFlow API error: {response.status} - {error_text}")
|
||
return "Failed to analyze image due to API error."
|
||
except Exception as e:
|
||
logger.error(f"Exception calling Vision API: {e}")
|
||
return f"Error analyzing image: {str(e)}"
|
||
|
||
@function_tool()
|
||
async def ask_image_capture(
|
||
self,
|
||
context: RunContext,
|
||
prompt: str = "请拍摄照片",
|
||
extract_target: str = "车辆号牌",
|
||
requirements: str = "清晰完整、无模糊",
|
||
):
|
||
"""
|
||
请求用户根据提供的提示语拍摄或上传照片,并检查照片中是否包含指定的提取目标且满足相关要求。
|
||
|
||
参数:
|
||
context (RunContext): 代理运行上下文。
|
||
prompt (str, 可选): 提示用户拍摄什么内容(例如:“请拍摄事故现场照片”,“请拍摄车牌号”)。默认值为“请拍摄照片”。
|
||
extract_target (str, 可选): 期望照片中包含的主要目标(例如:“车辆号牌”)。默认值为“车辆号牌”。
|
||
requirements (str, 可选): 拍摄的具体要求说明(例如:清晰、构图完整)。默认值为“清晰完整、无模糊”。
|
||
|
||
返回:
|
||
dict: 简洁可序列化的工具结果。
|
||
- success: bool
|
||
- analysis: 解析后的结果(解析失败则为字符串)
|
||
- raw: 原始文本结果(字符串)
|
||
- error/details: 仅失败时返回
|
||
"""
|
||
await self._send_chat_message(
|
||
"┌─🔨 Call: ask_image_capture\n"
|
||
f"│ prompt: \"{prompt}\"\n"
|
||
"└───────────────"
|
||
)
|
||
try:
|
||
room = get_job_context().room
|
||
participant_identity = next(iter(room.remote_participants))
|
||
|
||
# Clear previous image event
|
||
self._image_event.clear()
|
||
|
||
# Speak the capture prompt so the user hears what to do
|
||
self.session.say(prompt, allow_interruptions=True)
|
||
|
||
# Ask for image capture and wait for user to capture/upload
|
||
response = await room.local_participant.perform_rpc(
|
||
destination_identity=participant_identity,
|
||
method="askImageCapture",
|
||
payload=json.dumps({
|
||
"prompt": prompt,
|
||
}),
|
||
response_timeout=120.0, # Longer timeout for user to capture image
|
||
)
|
||
logger.info(f"Received image capture response: {response}")
|
||
|
||
# Parse the response to check for success/error (image data comes via byte stream)
|
||
try:
|
||
response_data = json.loads(response)
|
||
|
||
if "error" in response_data:
|
||
return {
|
||
"success": False,
|
||
"error": "图片采集失败",
|
||
"details": response_data.get("error"),
|
||
}
|
||
|
||
# Wait for the image byte stream to complete (signaled by _image_received)
|
||
if not self._image_event.is_set():
|
||
try:
|
||
# Give a reasonable buffer for file transfer to finish after RPC returns
|
||
await asyncio.wait_for(self._image_event.wait(), timeout=10.0)
|
||
except asyncio.TimeoutError:
|
||
if not self._captured_images:
|
||
return {
|
||
"success": False,
|
||
"error": "未收到图片数据",
|
||
"details": "Timeout waiting for file stream",
|
||
}
|
||
|
||
# Use the LAST stored image
|
||
last_image_bytes = self._captured_images[-1]
|
||
logger.info(f"Analyzing last stored image, size: {len(last_image_bytes)} bytes")
|
||
|
||
base64_image = base64.b64encode(last_image_bytes).decode('utf-8')
|
||
|
||
# Analyze the image using the configured vision backend
|
||
description = await self.describe_image_with_api(
|
||
base64_image,
|
||
self.vision_backend,
|
||
prompt=prompt,
|
||
extract_target=extract_target,
|
||
requirements=requirements,
|
||
)
|
||
logger.info(f"Image analysis result: {description}")
|
||
|
||
# Try to parse analysis as JSON; fall back to raw string
|
||
raw_text = description if isinstance(description, str) else json.dumps(description, ensure_ascii=False)
|
||
parsed_analysis = raw_text
|
||
try:
|
||
parsed_analysis = json.loads(raw_text)
|
||
except Exception:
|
||
parsed_analysis = raw_text
|
||
|
||
# Derive validity signals for the LLM to decide on retries
|
||
valid = False
|
||
problems = []
|
||
target_details = None
|
||
if isinstance(parsed_analysis, dict):
|
||
target_found = bool(parsed_analysis.get("target_found"))
|
||
quality_ok = bool(parsed_analysis.get("quality_ok"))
|
||
valid = target_found and quality_ok
|
||
target_details = parsed_analysis.get("target_details")
|
||
if not target_found:
|
||
problems.append(f"未发现{extract_target}")
|
||
if not quality_ok:
|
||
notes = parsed_analysis.get("quality_notes") or []
|
||
problems.extend(notes)
|
||
|
||
# Format the full analysis output for chat
|
||
analysis_display = str(raw_text)
|
||
if isinstance(parsed_analysis, dict):
|
||
# Pretty print JSON for better readability
|
||
analysis_display = json.dumps(parsed_analysis, ensure_ascii=False, indent=2)
|
||
|
||
await self._send_chat_message(
|
||
"┌─✅ Result: ask_image_capture\n"
|
||
f"│ valid: {valid}, problems: {problems}\n"
|
||
f"│ analysis:\n{analysis_display}\n"
|
||
"└───────────────"
|
||
)
|
||
|
||
return {
|
||
"success": True,
|
||
"analysis": parsed_analysis,
|
||
"raw": raw_text,
|
||
"valid": valid,
|
||
"problems": problems,
|
||
"target_details": target_details,
|
||
}
|
||
|
||
except json.JSONDecodeError:
|
||
logger.error(f"Failed to parse response: {response}")
|
||
return {
|
||
"success": False,
|
||
"error": "图片数据解析失败",
|
||
"details": response,
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to capture and analyze image: {e}")
|
||
raise ToolError(f"无法完成图片采集和分析: {str(e)}")
|
||
|
||
@function_tool()
|
||
async def get_mobile_by_plate(
|
||
self,
|
||
context: RunContext,
|
||
plate_number: str,
|
||
):
|
||
"""
|
||
根据车牌号查询登记的手机号。
|
||
|
||
返回:
|
||
dict: { success: bool, plate: str, mobile: str | None, error?: str }
|
||
"""
|
||
normalized_plate = plate_number.strip().upper()
|
||
await self._send_chat_message(
|
||
"┌─🔨 Call: get_mobile_by_plate\n"
|
||
f"│ plate: \"{normalized_plate}\"\n"
|
||
"└───────────────"
|
||
)
|
||
# Generate random mobile number (11 digits: 1[3-9] + 9 random digits)
|
||
mobile_prefix = random.choice(['13', '14', '15', '16', '17', '18', '19'])
|
||
mobile_suffix = ''.join([str(random.randint(0, 9)) for _ in range(9)])
|
||
random_mobile = f"{mobile_prefix}{mobile_suffix}"
|
||
|
||
return {
|
||
"success": True,
|
||
"plate": normalized_plate,
|
||
"mobile": random_mobile,
|
||
}
|
||
|
||
@function_tool()
|
||
async def get_id_card_by_plate(
|
||
self,
|
||
context: RunContext,
|
||
plate_number: str,
|
||
):
|
||
"""
|
||
根据车牌号查询登记的身份证号。
|
||
|
||
返回:
|
||
dict: { success: bool, plate: str, id_card: str | None, error?: str }
|
||
"""
|
||
normalized_plate = plate_number.strip().upper()
|
||
await self._send_chat_message(
|
||
"┌─🔨 Call: get_id_card_by_plate\n"
|
||
f"│ plate: \"{normalized_plate}\"\n"
|
||
"└───────────────"
|
||
)
|
||
# Generate random ID card number (18 digits: 6-digit area code + 8-digit birth date + 3-digit sequence + 1 check digit)
|
||
# Area code: random 6 digits (typically 110000-659999 for Chinese ID cards)
|
||
area_code = random.randint(110000, 659999)
|
||
# Birth date: random date between 1950-01-01 and 2000-12-31
|
||
year = random.randint(1950, 2000)
|
||
month = random.randint(1, 12)
|
||
day = random.randint(1, 28) # Use 28 to avoid month-specific day issues
|
||
birth_date = f"{year:04d}{month:02d}{day:02d}"
|
||
# Sequence number: 3 random digits
|
||
sequence = random.randint(100, 999)
|
||
# Check digit: random digit or X (10% chance of X)
|
||
check_digit = 'X' if random.random() < 0.1 else str(random.randint(0, 9))
|
||
random_id_card = f"{area_code}{birth_date}{sequence}{check_digit}"
|
||
|
||
return {
|
||
"success": True,
|
||
"plate": normalized_plate,
|
||
"id_card": random_id_card,
|
||
}
|
||
|
||
@function_tool()
|
||
async def validate_mobile_number(
|
||
self,
|
||
context: RunContext,
|
||
mobile: str,
|
||
):
|
||
"""
|
||
检查手机号格式是否正确(大陆 11 位手机号,1[3-9] 开头)。
|
||
|
||
返回:
|
||
dict: { success: bool, valid: bool, mobile: str, error?: str }
|
||
"""
|
||
normalized = mobile.strip().replace(" ", "")
|
||
await self._send_chat_message(
|
||
"┌─🔨 Call: validate_mobile_number\n"
|
||
f"│ mobile: \"{normalized}\"\n"
|
||
"└───────────────"
|
||
)
|
||
is_valid = bool(re.fullmatch(r"1[3-9]\\d{9}", normalized))
|
||
if is_valid:
|
||
return {
|
||
"success": True,
|
||
"valid": True,
|
||
"mobile": normalized,
|
||
}
|
||
return {
|
||
"success": True,
|
||
"valid": False,
|
||
"mobile": normalized,
|
||
"error": "手机号格式不正确,应为1[3-9]开头的11位数字",
|
||
}
|
||
|
||
@function_tool()
|
||
async def validate_id_card_number(
|
||
self,
|
||
context: RunContext,
|
||
id_card: str,
|
||
):
|
||
"""
|
||
检查身份证号格式是否正确(18位含校验位X或15位纯数字)。
|
||
|
||
返回:
|
||
dict: { success: bool, valid: bool, id_card: str, error?: str }
|
||
"""
|
||
normalized = id_card.strip().replace(" ", "").upper()
|
||
await self._send_chat_message(
|
||
"┌─🔨 Call: validate_id_card_number\n"
|
||
f"│ id_card: \"{normalized}\"\n"
|
||
"└───────────────"
|
||
)
|
||
is_valid = bool(re.fullmatch(r"(\\d{17}[\\dX]|\\d{15})", normalized))
|
||
if is_valid:
|
||
return {
|
||
"success": True,
|
||
"valid": True,
|
||
"id_card": normalized,
|
||
}
|
||
return {
|
||
"success": True,
|
||
"valid": False,
|
||
"id_card": normalized,
|
||
"error": "身份证格式不正确,应为18位(末位可为X)或15位数字",
|
||
}
|
||
|
||
@function_tool()
|
||
async def enter_hand_off_to_human_mode(
|
||
self,
|
||
context: RunContext,
|
||
):
|
||
"""切换到“转人工”模式(前端电话界面进入人工处理)。返回成功/失败。"""
|
||
await self._send_chat_message("🔨 Call: enter_hand_off_to_human_mode")
|
||
try:
|
||
room = get_job_context().room
|
||
participant_identity = next(iter(room.remote_participants))
|
||
response = await room.local_participant.perform_rpc(
|
||
destination_identity=participant_identity,
|
||
method="enterHandOffToHumanMode",
|
||
payload=json.dumps({}),
|
||
response_timeout=5.0,
|
||
)
|
||
logger.info(f"Entered hand off to human mode: {response}")
|
||
await self._send_chat_message(f"✅ Result: enter_hand_off_to_human_mode\n • status: success")
|
||
return response
|
||
except Exception as e:
|
||
logger.error(f"Failed to enter hand off to human mode: {e}")
|
||
raise ToolError(f"Unable to enter hand off to human mode: {str(e)}")
|
||
|
||
@function_tool()
|
||
async def hang_up_call(
|
||
self,
|
||
context: RunContext,
|
||
):
|
||
"""挂断当前通话(结束会话),返回成功/失败。"""
|
||
await self._send_chat_message("🔨 Call: hang_up_call")
|
||
try:
|
||
room = get_job_context().room
|
||
participant_identity = next(iter(room.remote_participants))
|
||
response = await room.local_participant.perform_rpc(
|
||
destination_identity=participant_identity,
|
||
method="hangUpCall",
|
||
payload=json.dumps({}),
|
||
response_timeout=5.0,
|
||
)
|
||
logger.info(f"Hung up call: {response}")
|
||
await self._send_chat_message(f"✅ Result: hang_up_call\n • status: disconnected")
|
||
return response
|
||
except Exception as e:
|
||
logger.error(f"Failed to hang up call: {e}")
|
||
raise ToolError(f"Unable to hang up call: {str(e)}")
|
||
|
||
@function_tool()
|
||
async def ask_important_question(self, context: RunContext, message: str, options: Optional[List[str]] | str = None):
|
||
"""询问关键问题并等待用户选择选项,返回用户的选择结果。
|
||
|
||
参数:
|
||
message: 要朗读/展示的问题。
|
||
options: 可选项列表(字符串数组),如 ["继续办理", "转人工"]。
|
||
|
||
返回:
|
||
str: 用户选择的文本内容。
|
||
"""
|
||
await self._send_chat_message(
|
||
"┌─🔨 Call: ask_important_question\n"
|
||
f"│ message: \"{message}\"\n"
|
||
f"│ options: {options}\n"
|
||
"└───────────────"
|
||
)
|
||
try:
|
||
room = get_job_context().room
|
||
participant_identity = next(iter(room.remote_participants))
|
||
|
||
payload_data = {
|
||
"message": message,
|
||
}
|
||
|
||
if options:
|
||
# Support both list input and JSON-stringified list input
|
||
if isinstance(options, str):
|
||
try:
|
||
options = json.loads(options)
|
||
except json.JSONDecodeError:
|
||
logger.error(f"Failed to parse options string: {options}")
|
||
options = None
|
||
|
||
# Ensure options is a list of strings
|
||
if isinstance(options, list):
|
||
payload_data["options"] = options
|
||
|
||
# Speak the message
|
||
speech_handle = self.session.say(message, allow_interruptions=True)
|
||
|
||
# Wait for user selection with longer timeout since user needs time to respond
|
||
response = await room.local_participant.perform_rpc(
|
||
destination_identity=participant_identity,
|
||
method="askImportantQuestion",
|
||
payload=json.dumps(payload_data),
|
||
response_timeout=60.0, # Increased timeout for user interaction
|
||
)
|
||
|
||
# Interrupt speech if user makes a selection while agent is speaking
|
||
if speech_handle and hasattr(speech_handle, "interrupt"):
|
||
speech_handle.interrupt()
|
||
logger.info("Interrupted speech due to user selection")
|
||
|
||
logger.info(f"User made selection: {response}")
|
||
|
||
# Parse the response to get the user's selection
|
||
try:
|
||
response_data = json.loads(response)
|
||
user_selection = response_data.get("selection", "确认")
|
||
logger.info(f"User selected: {user_selection}")
|
||
|
||
await self._send_chat_message(
|
||
"┌─✅ Result: ask_important_question\n"
|
||
f"│ selection: \"{user_selection}\"\n"
|
||
"└───────────────"
|
||
)
|
||
return f"用户选择了: {user_selection}"
|
||
except json.JSONDecodeError:
|
||
logger.error(f"Failed to parse response: {response}")
|
||
return f"用户选择了: {response}"
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to ask important question: {e}")
|
||
raise ToolError(f"Unable to ask important question: {str(e)}")
|
||
|
||
|
||
def prewarm(proc: JobProcess):
|
||
proc.userdata["vad"] = silero.VAD.load()
|
||
|
||
|
||
async def launch_avatar(ctx: JobContext, avatar_dispatcher_url: str, avatar_identity: str) -> None:
|
||
"""
|
||
Send a request to the avatar service for it to join the room
|
||
|
||
This function should be wrapped in a avatar plugin.
|
||
"""
|
||
|
||
# create a token for the avatar to join the room
|
||
token = (
|
||
api.AccessToken()
|
||
.with_identity(avatar_identity)
|
||
.with_name("Avatar Runner")
|
||
.with_grants(api.VideoGrants(room_join=True, room=ctx.room.name))
|
||
.with_kind("agent")
|
||
.with_attributes({ATTRIBUTE_PUBLISH_ON_BEHALF: ctx.local_participant_identity})
|
||
.to_jwt()
|
||
)
|
||
|
||
logger.info(f"Sending connection info to avatar dispatcher {avatar_dispatcher_url}")
|
||
connection_info = AvatarConnectionInfo(room_name=ctx.room.name, url=ctx._info.url, token=token)
|
||
async with httpx.AsyncClient() as client:
|
||
response = await client.post(avatar_dispatcher_url, json=asdict(connection_info))
|
||
response.raise_for_status()
|
||
logger.info("Avatar handshake completed")
|
||
|
||
|
||
async def entrypoint(ctx: JobContext, avatar_dispatcher_url: str = None, vision_backend: str = "siliconflow", llm_backend: str = "deepseek"):
|
||
# inactivity_task: asyncio.Task | None = None
|
||
|
||
# async def user_presence_task():
|
||
# # try to ping the user 3 times, if we get no answer, close the session
|
||
# for _ in range(3):
|
||
# await session.generate_reply(
|
||
# instructions=(
|
||
# "你礼貌的询问用户是否还在通话。"
|
||
# )
|
||
# )
|
||
# await asyncio.sleep(10)
|
||
|
||
# session.shutdown()
|
||
|
||
# each log entry will include these fields
|
||
ctx.log_context_fields = {
|
||
"room": ctx.room.name,
|
||
}
|
||
|
||
# IMPORTANT: Connect to room first
|
||
logger.info("connecting to room")
|
||
await ctx.connect()
|
||
|
||
logger.info("waiting for participant")
|
||
participant = await ctx.wait_for_participant()
|
||
logger.info(f"starting agent for participant {participant.identity}")
|
||
|
||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||
logger.info(f"Current time: {current_time}")
|
||
|
||
# Calculate weekday in Chinese
|
||
weekday_map = {
|
||
0: "星期一",
|
||
1: "星期二",
|
||
2: "星期三",
|
||
3: "星期四",
|
||
4: "星期五",
|
||
5: "星期六",
|
||
6: "星期日"
|
||
}
|
||
current_weekday = weekday_map[datetime.now().weekday()]
|
||
logger.info(f"Current weekday: {current_weekday}")
|
||
|
||
initial_voice_id = "BV001_streaming" # Female voice
|
||
if participant.attributes.get("voice"):
|
||
initial_voice_id = participant.attributes.get("voice")
|
||
logger.info(f"User selected voice: {initial_voice_id}")
|
||
|
||
# initial_instructions = """Your name is Kelly. You would interact with users via voice.
|
||
# with that in mind keep your responses concise and to the point.
|
||
# do not use emojis, asterisks, markdown, or other special characters in your responses.
|
||
# You are curious and friendly, and have a sense of humor.
|
||
# you will speak chinese to the user"
|
||
# """
|
||
initial_instructions = DEFAULT_INSTRUCTIONS
|
||
if participant.attributes.get("instructions"):
|
||
initial_instructions = participant.attributes.get("instructions")
|
||
logger.info(f"User selected instructions: {initial_instructions}")
|
||
|
||
# Replace the datetime and weekday placeholders to avoid KeyError from other braces in the prompt
|
||
initial_instructions = initial_instructions.replace("{datetime}", current_time)
|
||
initial_instructions = initial_instructions.replace("{weekday}", current_weekday)
|
||
logger.info(f"Initial instructions: {initial_instructions}")
|
||
|
||
if llm_backend == "dashscope":
|
||
logger.info("Using DashScope DeepSeek backend")
|
||
llm = openai.LLM(
|
||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||
model="deepseek-v3.2",
|
||
)
|
||
else:
|
||
logger.info("Using default DeepSeek backend")
|
||
llm = openai.LLM.with_deepseek(
|
||
model='deepseek-chat',
|
||
)
|
||
|
||
session = AgentSession(
|
||
vad=ctx.proc.userdata["vad"],
|
||
# turn_detection=MultilingualModel(),
|
||
# any combination of STT, LLM, TTS, or realtime API can be used
|
||
# stt = aliyun.STT(model="paraformer-realtime-v2"),
|
||
stt = volcengine.BigModelSTT(
|
||
app_id="8607675070"
|
||
),
|
||
# stt = volcengine.STT(
|
||
# app_id="2931820332",
|
||
# cluster="volcengine_streaming_common"
|
||
# ),
|
||
llm=llm,
|
||
# tts=aliyun.TTS(model="cosyvoice-v2", voice="longcheng_v2"),
|
||
tts=volcengine.TTS(
|
||
app_id="2931820332",
|
||
cluster="volcano_tts",
|
||
voice=initial_voice_id,
|
||
sample_rate=8000,
|
||
speed=1.1,
|
||
),
|
||
# allow the LLM to generate a response while waiting for the end of turn
|
||
preemptive_generation=True,
|
||
# sometimes background noise could interrupt the agent session, these are considered false positive interruptions
|
||
# when it's detected, you may resume the agent's speech
|
||
resume_false_interruption=True,
|
||
false_interruption_timeout=1.0,
|
||
# Increase the maximum number of function calls per turn to avoid hitting the limit
|
||
max_tool_steps=15,
|
||
)
|
||
room_io = RoomIO(session, room=ctx.room)
|
||
await room_io.start()
|
||
|
||
# log metrics as they are emitted, and total usage after session is over
|
||
usage_collector = metrics.UsageCollector()
|
||
|
||
@session.on("metrics_collected")
|
||
def _on_metrics_collected(ev: MetricsCollectedEvent):
|
||
metrics.log_metrics(ev.metrics)
|
||
usage_collector.collect(ev.metrics)
|
||
|
||
async def log_usage():
|
||
summary = usage_collector.get_summary()
|
||
logger.info(f"Usage: {summary}")
|
||
|
||
# shutdown callbacks are triggered when the session is over
|
||
ctx.add_shutdown_callback(log_usage)
|
||
|
||
# Launch avatar if avatar_dispatcher_url is provided
|
||
if avatar_dispatcher_url:
|
||
await launch_avatar(ctx, avatar_dispatcher_url, AVATAR_IDENTITY)
|
||
session.output.audio = DataStreamAudioOutput(
|
||
ctx.room,
|
||
destination_identity=AVATAR_IDENTITY,
|
||
# (optional) wait for the avatar to publish video track before generating a reply
|
||
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
|
||
)
|
||
|
||
@session.output.audio.on("playback_finished")
|
||
def on_playback_finished(ev: PlaybackFinishedEvent) -> None:
|
||
# the avatar should notify when the audio playback is finished
|
||
logger.info(
|
||
"playback_finished",
|
||
extra={
|
||
"playback_position": ev.playback_position,
|
||
"interrupted": ev.interrupted,
|
||
},
|
||
)
|
||
|
||
# @session.on("user_state_changed")
|
||
# def _user_state_changed(ev: UserStateChangedEvent):
|
||
# nonlocal inactivity_task
|
||
# if ev.new_state == "away":
|
||
# inactivity_task = asyncio.create_task(user_presence_task())
|
||
# return
|
||
|
||
# # ev.new_state: listening, speaking, ..
|
||
# if inactivity_task is not None:
|
||
# inactivity_task.cancel()
|
||
|
||
await session.start(
|
||
agent=MyAgent(initial_instructions, vision_backend=vision_backend),
|
||
room=ctx.room,
|
||
room_input_options=RoomInputOptions(
|
||
# uncomment to enable Krisp BVC noise cancellation
|
||
# noise_cancellation=noise_cancellation.BVC(),
|
||
),
|
||
room_output_options=RoomOutputOptions(transcription_enabled=True),
|
||
)
|
||
|
||
# disable input audio at the start
|
||
_talking_mode = DEFAULT_TALKING_MODE
|
||
if _talking_mode == "push_to_talk":
|
||
session.input.set_audio_enabled(False)
|
||
else:
|
||
session.input.set_audio_enabled(True)
|
||
|
||
@ctx.room.local_participant.register_rpc_method("start_turn")
|
||
async def start_turn(data: rtc.RpcInvocationData):
|
||
try:
|
||
session.interrupt()
|
||
except RuntimeError as e:
|
||
logger.error(f"Failed to interrupt session: {e}")
|
||
# Raise RPC error so client can detect interrupt failure
|
||
# Use ERROR_INTERNAL (code 13) to indicate application error
|
||
raise rtc.RpcError(
|
||
code=13, # ERROR_INTERNAL
|
||
message="Application error in method handler"
|
||
)
|
||
|
||
session.clear_user_turn()
|
||
|
||
# listen to the caller if multi-user
|
||
room_io.set_participant(data.caller_identity)
|
||
session.input.set_audio_enabled(True)
|
||
|
||
@ctx.room.local_participant.register_rpc_method("end_turn")
|
||
async def end_turn(data: rtc.RpcInvocationData):
|
||
session.input.set_audio_enabled(False)
|
||
session.commit_user_turn(
|
||
# the timeout for the final transcript to be received after committing the user turn
|
||
# increase this value if the STT is slow to respond
|
||
transcript_timeout=10.0,
|
||
# the duration of the silence to be appended to the STT to make it generate the final transcript
|
||
stt_flush_duration=2.0,
|
||
)
|
||
|
||
@ctx.room.local_participant.register_rpc_method("cancel_turn")
|
||
async def cancel_turn(data: rtc.RpcInvocationData):
|
||
session.input.set_audio_enabled(False)
|
||
session.clear_user_turn()
|
||
logger.info("cancel turn")
|
||
|
||
@ctx.room.local_participant.register_rpc_method("switch_ptt_and_rt")
|
||
async def switch_ptt_and_rt(data: rtc.RpcInvocationData):
|
||
nonlocal _talking_mode
|
||
_talking_mode = "push_to_talk" if _talking_mode == "realtime" else "realtime"
|
||
if _talking_mode == "push_to_talk":
|
||
session.input.set_audio_enabled(False)
|
||
else:
|
||
session.input.set_audio_enabled(True)
|
||
return json.dumps({"success": True, "mode": _talking_mode})
|
||
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument("--avatar-url", type=str, default=None, help="Avatar dispatcher URL")
|
||
parser.add_argument("--vision-backend", type=str, default="aliyun", choices=["siliconflow", "aliyun"], help="Vision API backend")
|
||
parser.add_argument("--llm-backend", type=str, default="dashscope", choices=["deepseek", "dashscope"], help="LLM backend")
|
||
args, remaining_args = parser.parse_known_args()
|
||
sys.argv = sys.argv[:1] + remaining_args
|
||
|
||
if args.avatar_url:
|
||
cli.run_app(WorkerOptions(entrypoint_fnc=partial(entrypoint, avatar_dispatcher_url=args.avatar_url, vision_backend=args.vision_backend, llm_backend=args.llm_backend), prewarm_fnc=prewarm))
|
||
else:
|
||
cli.run_app(WorkerOptions(entrypoint_fnc=partial(entrypoint, vision_backend=args.vision_backend, llm_backend=args.llm_backend), prewarm_fnc=prewarm))
|