Refactor WebSocket authentication handling by removing auth requirements from the hello message. Update related documentation and schemas to reflect the changes in authentication strategy, simplifying the connection process.

This commit is contained in:
Xin Wang
2026-02-28 17:33:40 +08:00
parent 0821d73e7c
commit b4fa664d73
7 changed files with 7 additions and 67 deletions

View File

@@ -46,16 +46,12 @@ Server <- session.stopped
#### 1. Handshake: `hello` #### 1. Handshake: `hello`
客户端连接后发送的第一个消息,用于协议版本协商和认证 客户端连接后发送的第一个消息,用于协议版本协商。
```json ```json
{ {
"type": "hello", "type": "hello",
"version": "v1", "version": "v1"
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
} }
``` ```
@@ -63,11 +59,6 @@ Server <- session.stopped
|---|---|---|---| |---|---|---|---|
| `type` | string | 是 | 固定为 `"hello"` | | `type` | string | 是 | 固定为 `"hello"` |
| `version` | string | 是 | 协议版本,固定为 `"v1"` | | `version` | string | 是 | 协议版本,固定为 `"v1"` |
| `auth` | object | 否 | 认证信息 |
**认证规则**
- 若配置了 `WS_API_KEY`,必须提供匹配的 `apiKey`
-`WS_REQUIRE_AUTH=true`,至少需要提供 `apiKey``jwt` 之一
--- ---
@@ -291,8 +282,6 @@ Server <- session.stopped
| `protocol.invalid_message` | 消息格式错误 | | `protocol.invalid_message` | 消息格式错误 |
| `protocol.order` | 消息顺序错误 | | `protocol.order` | 消息顺序错误 |
| `protocol.version_unsupported` | 协议版本不支持 | | `protocol.version_unsupported` | 协议版本不支持 |
| `auth.invalid_api_key` | API Key 无效 |
| `auth.required` | 需要认证 |
| `audio.invalid_pcm` | PCM 数据无效 | | `audio.invalid_pcm` | PCM 数据无效 |
| `audio.frame_size_mismatch` | 音频帧大小不匹配 | | `audio.frame_size_mismatch` | 音频帧大小不匹配 |
| `server.internal` | 服务端内部错误 | | `server.internal` | 服务端内部错误 |

View File

@@ -55,8 +55,6 @@ LOG_FORMAT=json
INACTIVITY_TIMEOUT_SEC=60 INACTIVITY_TIMEOUT_SEC=60
HEARTBEAT_INTERVAL_SEC=50 HEARTBEAT_INTERVAL_SEC=50
WS_PROTOCOL_VERSION=v1 WS_PROTOCOL_VERSION=v1
# WS_API_KEY=replace_with_shared_secret
WS_REQUIRE_AUTH=false
# CORS / ICE (JSON strings) # CORS / ICE (JSON strings)
CORS_ORIGINS=["http://localhost:3000","http://localhost:8080"] CORS_ORIGINS=["http://localhost:3000","http://localhost:8080"]

View File

@@ -493,8 +493,6 @@ class Settings(BaseSettings):
inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)") inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds") heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
ws_protocol_version: str = Field(default="v1", description="Public WS protocol version") ws_protocol_version: str = Field(default="v1", description="Public WS protocol version")
ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth")
ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")
# Backend bridge configuration (for call/transcript persistence) # Backend bridge configuration (for call/transcript persistence)
backend_mode: str = Field( backend_mode: str = Field(

View File

@@ -124,7 +124,6 @@ class Session:
self.ws_state = WsSessionState.WAIT_HELLO self.ws_state = WsSessionState.WAIT_HELLO
self._pipeline_started = False self._pipeline_started = False
self.protocol_version: Optional[str] = None self.protocol_version: Optional[str] = None
self.authenticated: bool = False
# Track IDs # Track IDs
self.current_track_id: str = self.TRACK_CONTROL self.current_track_id: str = self.TRACK_CONTROL
@@ -264,7 +263,7 @@ class Session:
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
async def _handle_hello(self, message: HelloMessage) -> None: async def _handle_hello(self, message: HelloMessage) -> None:
"""Handle initial hello/auth/version negotiation.""" """Handle initial hello/version negotiation."""
if self.ws_state != WsSessionState.WAIT_HELLO: if self.ws_state != WsSessionState.WAIT_HELLO:
await self._send_error("client", "Duplicate hello", "protocol.order") await self._send_error("client", "Duplicate hello", "protocol.order")
return return
@@ -279,23 +278,6 @@ class Session:
self.ws_state = WsSessionState.STOPPED self.ws_state = WsSessionState.STOPPED
return return
auth_payload = message.auth
api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key:
if api_key != settings.ws_api_key:
await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
elif settings.ws_require_auth and not (api_key or jwt):
await self._send_error("auth", "Authentication required", "auth.required")
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
self.authenticated = True
self.protocol_version = message.version self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START self.ws_state = WsSessionState.WAIT_START
await self._send_event( await self._send_event(
@@ -701,8 +683,6 @@ class Session:
return self.TRACK_AUDIO_IN return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}: if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT return self.TRACK_AUDIO_OUT
if str(code or "").strip().lower().startswith("auth."):
return self.TRACK_CONTROL
return self.TRACK_CONTROL return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -33,18 +33,12 @@ If order is violated, server emits `error` with `code = "protocol.order"`.
```json ```json
{ {
"type": "hello", "type": "hello",
"version": "v1", "version": "v1"
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
} }
``` ```
Rules: Rules:
- `version` must be `v1`. - `version` must be `v1`.
- If `WS_API_KEY` is configured on server, `auth.apiKey` must match.
- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present.
### `session.start` ### `session.start`
@@ -205,7 +199,7 @@ Common events:
- `trackId` convention: - `trackId` convention:
- `audio_in` for `stage in {audio, asr}` - `audio_in` for `stage in {audio, asr}`
- `audio_out` for `stage in {llm, tts, tool}` - `audio_out` for `stage in {llm, tts, tool}`
- `control` otherwise (including protocol/auth errors) - `control` otherwise (including protocol errors)
Track IDs (MVP fixed values): Track IDs (MVP fixed values):
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)

View File

@@ -62,11 +62,7 @@
```json ```json
{ {
"type": "hello", "type": "hello",
"version": "v1", "version": "v1"
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
} }
``` ```
@@ -76,13 +72,6 @@
|---|---|---|---|---|---| |---|---|---|---|---|---|
| `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 | | `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 |
| `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 | | `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 |
| `auth` | object \| null | 否 | 仅允许 `apiKey``jwt` | 认证载荷 | 认证策略由服务端配置决定 |
| `auth.apiKey` | string \| null | 否 | 任意字符串 | API Key | 若服务端配置 `WS_API_KEY`,必须精确匹配 |
| `auth.jwt` | string \| null | 否 | 任意字符串 | JWT 字符串 | 当 `WS_REQUIRE_AUTH=true` 时可用于满足“有认证信息”条件 |
认证行为:
- 若设置了 `WS_API_KEY`:必须提供且匹配 `auth.apiKey`,否则 `auth.invalid_api_key` 并关闭连接。
-`WS_REQUIRE_AUTH=true` 且未设置 `WS_API_KEY``auth.apiKey``auth.jwt` 至少一个非空,否则 `auth.required` 并关闭连接。
## 3.2 `session.start` ## 3.2 `session.start`
@@ -472,7 +461,7 @@
``` ```
字段语义: 字段语义:
- `sender`:错误来源角色(如 `client` / `server` / `auth` - `sender`:错误来源角色(如 `client` / `server`
- `code`:机器可读错误码 - `code`:机器可读错误码
- `message`:人类可读描述 - `message`:人类可读描述
- `stage`:阶段(`protocol|audio|asr|llm|tts|tool` - `stage`:阶段(`protocol|audio|asr|llm|tts|tool`
@@ -485,8 +474,6 @@
- `protocol.order` - `protocol.order`
- `protocol.version_unsupported` - `protocol.version_unsupported`
- `protocol.unsupported` - `protocol.unsupported`
- `auth.invalid_api_key`
- `auth.required`
- `audio.invalid_pcm` - `audio.invalid_pcm`
- `audio.frame_size_mismatch` - `audio.frame_size_mismatch`
- `audio.processing_failed` - `audio.processing_failed`

View File

@@ -19,15 +19,9 @@ class _StrictModel(BaseModel):
# Client -> Server messages # Client -> Server messages
class HelloAuth(_StrictModel):
apiKey: Optional[str] = None
jwt: Optional[str] = None
class HelloMessage(_StrictModel): class HelloMessage(_StrictModel):
type: Literal["hello"] type: Literal["hello"]
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1") version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
class SessionStartAudio(_StrictModel): class SessionStartAudio(_StrictModel):