diff --git a/pyproject.toml b/pyproject.toml index f7c73d49a..cd1ae27ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-ope livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity~=9.0.0" ] lmnt = [ "websockets~=13.1" ] local = [ "pyaudio~=0.2.14" ] -mcp = [ "mcp[cli]~=1.6.0" ] +mcp = [ "mcp[cli]~=1.9.4" ] mem0 = [ "mem0ai~=0.1.94" ] mlx-whisper = [ "mlx-whisper~=0.4.2" ] moondream = [ "einops~=0.8.0", "timm~=1.0.13", "transformers~=4.48.0" ] diff --git a/src/pipecat/services/mcp_service.py b/src/pipecat/services/mcp_service.py index a644d8f1b..3e1b04681 100644 --- a/src/pipecat/services/mcp_service.py +++ b/src/pipecat/services/mcp_service.py @@ -12,6 +12,7 @@ try: from mcp.client.session_group import SseServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import streamablehttp_client except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.") @@ -27,12 +28,17 @@ class MCPClient(BaseObject): super().__init__(**kwargs) self._server_params = server_params self._session = ClientSession + self._additional_headers = additional_headers or {} + if isinstance(server_params, StdioServerParameters): self._client = stdio_client self._register_tools = self._stdio_register_tools elif isinstance(server_params, SseServerParameters): self._client = sse_client self._register_tools = self._sse_register_tools + elif isinstance(server_params, str) and streamable_http: + self._client = streamablehttp_client + self._register_tools = self._streamable_http_register_tools else: raise TypeError( f"{self} invalid argument type: `server_params` must be either StdioServerParameters or SseServerParameters." @@ -156,6 +162,44 @@ class MCPClient(BaseObject): tools_schema = await self._list_tools(session, mcp_tool_wrapper, llm) return tools_schema + async def _streamable_http_register_tools(self, llm) -> ToolsSchema: + """Register all available mcp.run tools with the LLM service using streamable HTTP. + Args: + llm: The Pipecat LLM service to register tools with + Returns: + A ToolsSchema containing all registered tools + """ + + async def mcp_tool_wrapper( + function_name: str, + tool_call_id: str, + arguments: Dict[str, Any], + llm: any, + context: any, + result_callback: any, + ) -> None: + """Wrapper for mcp.run tool calls to match Pipecat's function call interface.""" + logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}") + logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}") + try: + async with self._client(self._server_params, headers=self._additional_headers) as (read_stream, write_stream, _): + async with self._session(read_stream, write_stream) as session: + await session.initialize() + await self._call_tool(session, function_name, arguments, result_callback) + except Exception as e: + error_msg = f"Error calling mcp tool {function_name}: {str(e)}" + logger.error(error_msg) + logger.exception("Full exception details:") + await result_callback(error_msg) + + logger.debug("Starting registration of mcp.run tools using streamable HTTP") + + async with self._client(self._server_params, headers=self._additional_headers) as (read_stream, write_stream, _): + async with self._session(read_stream, write_stream) as session: + await session.initialize() + tools_schema = await self._list_tools(session, mcp_tool_wrapper, llm) + return tools_schema + async def _call_tool(self, session, function_name, arguments, result_callback): logger.debug(f"Calling mcp tool '{function_name}'") try: