fixed MCPClient to reuse session across tool calls
This commit is contained in:
@@ -146,90 +146,77 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
rijksmuseum_mcp = MCPClient(
|
||||
async with (
|
||||
MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up rijksmuseum mcp")
|
||||
logger.exception("error trace:")
|
||||
try:
|
||||
) as rijksmuseum_mcp,
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
github_mcp = MCPClient(
|
||||
MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={
|
||||
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp.run")
|
||||
logger.exception("error trace:")
|
||||
|
||||
rijksmuseum_tools = {}
|
||||
github_tools = {}
|
||||
try:
|
||||
),
|
||||
) as github_mcp,
|
||||
):
|
||||
rijksmuseum_tools = await rijksmuseum_mcp.register_tools(llm)
|
||||
github_tools = await github_mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
|
||||
context = LLMContext(tools=all_tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
mcp_image_processor = UrlToImageProcessor(aiohttp_session=session)
|
||||
context = LLMContext(tools=all_tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
mcp_image_processor = UrlToImageProcessor(aiohttp_session=session)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
mcp_image_processor, # URL image -> output
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
mcp_image_processor, # URL image -> output
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
|
||||
@@ -162,73 +162,63 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
),
|
||||
# Optional
|
||||
tools_filter=mcp_tools_filter, # Optional
|
||||
tools_output_filters={"open_image_in_browser": open_image_output_filter},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
mcp_image = UrlToImageProcessor(aiohttp_session=session)
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
mcp_image, # URL image -> output
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
async with MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
# Optional
|
||||
tools_filter=mcp_tools_filter, # Optional
|
||||
tools_output_filters={"open_image_in_browser": open_image_output_filter},
|
||||
) as mcp:
|
||||
tools = await mcp.register_tools(llm)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
mcp_image, # URL image -> output
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
await runner.run(task)
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
|
||||
@@ -63,28 +63,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.get_tools_schema()
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to answer questions about the user's GitHub repositories and account.
|
||||
@@ -94,53 +72,65 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system,
|
||||
tools=tools,
|
||||
)
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
async with MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
) as mcp:
|
||||
tools = await mcp.get_tools_schema()
|
||||
|
||||
await mcp.register_tools_schema(tools, llm)
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
context = LLMContext([{"role": "developer", "content": "Please introduce yourself."}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
await mcp.register_tools_schema(tools, llm)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
context = LLMContext([{"role": "user", "content": "Please introduce yourself."}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
await runner.run(task)
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
|
||||
@@ -77,69 +77,59 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
system_instruction=system_prompt,
|
||||
)
|
||||
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
async with MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
) as mcp:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
@@ -36,8 +37,14 @@ class MCPClient(BaseObject):
|
||||
"""Client for Model Context Protocol (MCP) servers.
|
||||
|
||||
Enables integration with MCP servers to provide external tools and resources
|
||||
to LLMs. Supports both stdio and SSE server connections with automatic tool
|
||||
registration and schema conversion.
|
||||
to LLMs. Supports stdio, SSE, and streamable HTTP server connections with
|
||||
automatic tool registration and schema conversion.
|
||||
|
||||
The client maintains a persistent connection to the MCP server. It must
|
||||
be used as an async context manager or explicitly started and closed::
|
||||
|
||||
async with MCPClient(server_params=...) as mcp:
|
||||
tools = await mcp.register_tools(llm)
|
||||
|
||||
Raises:
|
||||
TypeError: If server_params is not a supported parameter type.
|
||||
@@ -53,7 +60,7 @@ class MCPClient(BaseObject):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
Args:
|
||||
server_params: Server connection parameters (stdio or SSE).
|
||||
server_params: Server connection parameters (stdio, SSE, or streamable HTTP).
|
||||
tools_filter: Optional list of tool names to register. If None, all tools are registered.
|
||||
tools_output_filters: Optional dict mapping tool names to filter functions that process tool outputs.
|
||||
Each filter function receives the raw tool output (any type) and returns the processed output (any type).
|
||||
@@ -61,31 +68,84 @@ class MCPClient(BaseObject):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._server_params = server_params
|
||||
self._session = ClientSession
|
||||
self._tools_filter = tools_filter
|
||||
self._tools_output_filters = tools_output_filters or {}
|
||||
self._exit_stack: Optional[AsyncExitStack] = None
|
||||
self._active_session: Optional[ClientSession] = None
|
||||
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
self._client = stdio_client
|
||||
self._list_tools = self._stdio_list_tools
|
||||
self._tool_wrapper = self._stdio_tool_wrapper
|
||||
elif isinstance(server_params, SseServerParameters):
|
||||
self._client = sse_client
|
||||
self._list_tools = self._sse_list_tools
|
||||
self._tool_wrapper = self._sse_tool_wrapper
|
||||
elif isinstance(server_params, StreamableHttpParameters):
|
||||
self._client = streamablehttp_client
|
||||
self._list_tools = self._streamable_http_list_tools
|
||||
self._tool_wrapper = self._streamable_http_tool_wrapper
|
||||
else:
|
||||
if not isinstance(
|
||||
server_params,
|
||||
(StdioServerParameters, SseServerParameters, StreamableHttpParameters),
|
||||
):
|
||||
raise TypeError(
|
||||
f"{self} invalid argument type: `server_params` must be either StdioServerParameters, SseServerParameters, or StreamableHttpParameters."
|
||||
f"{self} invalid argument type: `server_params` must be either "
|
||||
"StdioServerParameters, SseServerParameters, or StreamableHttpParameters."
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start a persistent connection to the MCP server.
|
||||
|
||||
Opens the transport and initializes the MCP session. The session
|
||||
is reused for all subsequent tool calls and schema requests until
|
||||
close() is called.
|
||||
|
||||
Can also be used via async context manager::
|
||||
|
||||
async with MCPClient(server_params=...) as mcp:
|
||||
...
|
||||
"""
|
||||
if self._active_session:
|
||||
return
|
||||
|
||||
# We manage the exit stack manually (not via `async with`) so we can
|
||||
# clean up partial resources on failure before assigning to self.
|
||||
exit_stack = AsyncExitStack()
|
||||
await exit_stack.__aenter__()
|
||||
|
||||
try:
|
||||
if isinstance(self._server_params, StdioServerParameters):
|
||||
streams = await exit_stack.enter_async_context(stdio_client(self._server_params))
|
||||
read_stream, write_stream = streams[0], streams[1]
|
||||
elif isinstance(self._server_params, SseServerParameters):
|
||||
read_stream, write_stream = await exit_stack.enter_async_context(
|
||||
sse_client(**self._server_params.model_dump())
|
||||
)
|
||||
else: # StreamableHttpParameters (validated in __init__)
|
||||
read_stream, write_stream, _ = await exit_stack.enter_async_context(
|
||||
streamablehttp_client(**self._server_params.model_dump())
|
||||
)
|
||||
|
||||
session = await exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||
await session.initialize()
|
||||
|
||||
self._exit_stack = exit_stack
|
||||
self._active_session = session
|
||||
|
||||
except Exception:
|
||||
await exit_stack.aclose()
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the persistent MCP connection.
|
||||
|
||||
Safe to call multiple times or without having called start().
|
||||
"""
|
||||
self._active_session = None
|
||||
if self._exit_stack:
|
||||
await self._exit_stack.aclose()
|
||||
self._exit_stack = None
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def register_tools(self, llm: LLMService | LLMSwitcher) -> ToolsSchema:
|
||||
"""Register all available MCP tools with an LLM service.
|
||||
|
||||
Connects to the MCP server, discovers available tools, converts their
|
||||
Discovers available tools from the active session, converts their
|
||||
schemas to Pipecat format, and registers them with the LLM service.
|
||||
|
||||
This is the equivalent of calling get_tools_schema() followed by
|
||||
@@ -101,18 +161,26 @@ class MCPClient(BaseObject):
|
||||
await self.register_tools_schema(tools_schema, llm)
|
||||
return tools_schema
|
||||
|
||||
def _ensure_connected(self) -> ClientSession:
|
||||
"""Return the active session or raise if not connected."""
|
||||
if not self._active_session:
|
||||
raise RuntimeError(
|
||||
"MCPClient is not connected. Use 'async with MCPClient(...) as mcp:' "
|
||||
"or call 'await mcp.start()' before using MCPClient."
|
||||
)
|
||||
return self._active_session
|
||||
|
||||
async def get_tools_schema(self) -> ToolsSchema:
|
||||
"""Get the schema of all available MCP tools without registering them.
|
||||
|
||||
Connects to the MCP server, discovers available tools, and converts their
|
||||
schemas to Pipecat format.
|
||||
Requires the client to be started via start() or async with.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all available tools. This can be used for
|
||||
subsequent registration using register_tools_schema().
|
||||
"""
|
||||
tools_schema = await self._list_tools()
|
||||
return tools_schema
|
||||
session = self._ensure_connected()
|
||||
return await self._list_tools_helper(session)
|
||||
|
||||
async def register_tools_schema(
|
||||
self, tools_schema: ToolsSchema, llm: LLMService | LLMSwitcher
|
||||
@@ -154,107 +222,21 @@ class MCPClient(BaseObject):
|
||||
|
||||
return schema
|
||||
|
||||
async def _sse_list_tools(self) -> ToolsSchema:
|
||||
"""List all available mcp tools with the LLM service.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all registered tools
|
||||
"""
|
||||
logger.debug(f"SSE server parameters: {self._server_params}")
|
||||
logger.debug(f"Starting reading mcp tools")
|
||||
|
||||
async with self._client(**self._server_params.model_dump()) as (read, write):
|
||||
async with self._session(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_schema = await self._list_tools_helper(session)
|
||||
return tools_schema
|
||||
|
||||
async def _sse_tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
|
||||
async def _tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Execute an MCP tool call using the persistent session."""
|
||||
session = self._ensure_connected()
|
||||
logger.debug(f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}")
|
||||
logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
|
||||
try:
|
||||
async with self._client(**self._server_params.model_dump()) as (read, write):
|
||||
async with self._session(read, write) as session:
|
||||
await session.initialize()
|
||||
await self._call_tool(
|
||||
session, params.function_name, params.arguments, params.result_callback
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _stdio_list_tools(self) -> ToolsSchema:
|
||||
"""List all available mcp tools with the LLM service.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all available tools.
|
||||
"""
|
||||
logger.debug(f"Starting reading mcp tools")
|
||||
|
||||
async with self._client(self._server_params) as streams:
|
||||
async with self._session(streams[0], streams[1]) as session:
|
||||
await session.initialize()
|
||||
tools_schema = await self._list_tools_helper(session)
|
||||
return tools_schema
|
||||
|
||||
async def _stdio_tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
|
||||
logger.debug(f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}")
|
||||
logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
|
||||
try:
|
||||
async with self._client(self._server_params) as streams:
|
||||
async with self._session(streams[0], streams[1]) as session:
|
||||
await session.initialize()
|
||||
await self._call_tool(
|
||||
session, params.function_name, params.arguments, params.result_callback
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _streamable_http_list_tools(self) -> ToolsSchema:
|
||||
"""List all available mcp tools with the LLM service using streamable HTTP.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all available tools.
|
||||
"""
|
||||
logger.debug(f"Starting reading mcp tools using streamable HTTP")
|
||||
|
||||
async with self._client(**self._server_params.model_dump()) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with self._session(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
tools_schema = await self._list_tools_helper(session)
|
||||
return tools_schema
|
||||
|
||||
async def _streamable_http_tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
|
||||
logger.debug(f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}")
|
||||
logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
|
||||
try:
|
||||
async with self._client(**self._server_params.model_dump()) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with self._session(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
await self._call_tool(
|
||||
session, params.function_name, params.arguments, params.result_callback
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await params.result_callback(error_msg)
|
||||
await self._call_tool(
|
||||
session,
|
||||
params.function_name,
|
||||
params.arguments,
|
||||
params.result_callback,
|
||||
)
|
||||
|
||||
async def _call_tool(self, session, function_name, arguments, result_callback):
|
||||
logger.debug(f"Calling mcp tool '{function_name}'")
|
||||
results = None
|
||||
try:
|
||||
results = await session.call_tool(function_name, arguments=arguments)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user