Update agent flow and frontend chat experience

This commit is contained in:
Xin Wang
2026-03-31 21:58:24 +08:00
parent e34e569de4
commit 0e2dafe440
13 changed files with 345 additions and 372 deletions

View File

@@ -1 +1,3 @@
# GEMINI_API_KEY=
# OPENAI_API_KEY=
# OPENAI_BASE_URL=
# TAVILY_API_KEY=

View File

@@ -21,7 +21,7 @@ def main() -> None:
)
parser.add_argument(
"--reasoning-model",
default="gemini-2.5-pro-preview-05-06",
default="openai/gpt-5.4",
help="Model for the final answer",
)
args = parser.parse_args()

View File

@@ -11,13 +11,13 @@ requires-python = ">=3.11,<4.0"
dependencies = [
"langgraph>=0.2.6",
"langchain>=0.3.19",
"langchain-google-genai",
"langchain-openai",
"python-dotenv>=1.0.1",
"langgraph-sdk>=0.1.57",
"langgraph-cli",
"langgraph-api",
"fastapi",
"google-genai",
"tavily-python",
]

View File

@@ -9,26 +9,33 @@ class Configuration(BaseModel):
"""The configuration for the agent."""
query_generator_model: str = Field(
default="gemini-2.0-flash",
default="openai/gpt-5.4",
metadata={
"description": "The name of the language model to use for the agent's query generation."
},
)
reflection_model: str = Field(
default="gemini-2.5-flash",
default="openai/gpt-5.4",
metadata={
"description": "The name of the language model to use for the agent's reflection."
},
)
answer_model: str = Field(
default="gemini-2.5-pro",
default="openai/gpt-5.4",
metadata={
"description": "The name of the language model to use for the agent's answer."
},
)
reasoning_model: str = Field(
default="openai/gpt-5.4",
metadata={
"description": "Fallback model used when the client does not provide a supported reasoning model."
},
)
number_of_initial_queries: int = Field(
default=3,
metadata={"description": "The number of initial search queries to generate."},

View File

@@ -1,91 +1,134 @@
import os
from typing import Any
from agent.tools_and_schemas import SearchQueryList, Reflection
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.types import Send
from langgraph.graph import StateGraph
from langgraph.graph import START, END
from langchain_core.runnables import RunnableConfig
from google.genai import Client
from agent.configuration import Configuration
from agent.prompts import (
answer_instructions,
get_current_date,
query_writer_instructions,
reflection_instructions,
)
from agent.state import (
OverallState,
QueryGenerationState,
ReflectionState,
WebSearchState,
)
from agent.configuration import Configuration
from agent.prompts import (
get_current_date,
query_writer_instructions,
web_searcher_instructions,
reflection_instructions,
answer_instructions,
)
from langchain_google_genai import ChatGoogleGenerativeAI
from agent.tools_and_schemas import Reflection, SearchQueryList
from agent.utils import (
get_citations,
format_sources_for_prompt,
get_research_topic,
insert_citation_markers,
resolve_urls,
normalize_model_name,
normalize_tavily_sources,
shorten_search_query,
)
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from tavily import TavilyClient
load_dotenv()
if os.getenv("GEMINI_API_KEY") is None:
raise ValueError("GEMINI_API_KEY is not set")
# Used for Google Search API
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))
# Tavily search defaults shared by every web_research call.
TAVILY_TOPIC = "general"  # general-purpose search vertical (not "news")
TAVILY_SEARCH_DEPTH = "advanced"  # Tavily's deeper search mode; passed to TavilyClient.search
TAVILY_MAX_RESULTS = 5  # maximum sources returned per query
TAVILY_CHUNKS_PER_SOURCE = 3  # content chunks Tavily extracts per source
def require_env(name: str) -> str:
    """Return the value of environment variable *name* or raise a clear error.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    # Empty strings are treated the same as unset variables.
    if value := os.getenv(name):
        return value
    raise ValueError(f"{name} is not set")
def get_chat_model(model_name: str, temperature: float = 0) -> ChatOpenAI:
    """Create the OpenAI chat model lazily for easier testing.

    Args:
        model_name: Requested model; normalized via ``normalize_model_name``.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    # An unset or empty OPENAI_BASE_URL falls back to the library default.
    return ChatOpenAI(
        model=normalize_model_name(model_name),
        api_key=require_env("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_BASE_URL") or None,
        temperature=temperature,
    )
def get_tavily_client() -> TavilyClient:
    """Create the Tavily client lazily for easier testing.

    Raises:
        ValueError: if TAVILY_API_KEY is unset (via ``require_env``).
    """
    api_key = require_env("TAVILY_API_KEY")
    return TavilyClient(api_key=api_key)
def create_structured_response(
    prompt: str,
    model_name: str,
    schema_model: type[Any],
    temperature: float = 1,
) -> Any:
    """Call the LangChain OpenAI chat model and parse structured output.

    Args:
        prompt: Fully formatted prompt string.
        model_name: Model identifier handed to ``get_chat_model``.
        schema_model: Pydantic/TypedDict schema the output is parsed into.
        temperature: Sampling temperature for the call.

    Returns:
        An instance of ``schema_model`` populated from the model response.
    """
    base_llm = get_chat_model(model_name, temperature=temperature)
    # function_calling keeps structured output compatible across providers.
    structured_llm = base_llm.with_structured_output(
        schema_model,
        method="function_calling",
    )
    return structured_llm.invoke(prompt)
def extract_text_content(content: Any) -> str:
    """Normalize LangChain message content into a plain string.

    Handles the shapes LangChain may return for ``message.content``: a plain
    string, a list of content blocks (typed dicts and/or raw strings), or any
    other object, which is stringified as a last resort.

    Args:
        content: The ``content`` attribute of a LangChain message.

    Returns:
        The extracted text stripped of surrounding whitespace; an empty
        string when the list form contains no text blocks.
    """
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        text_chunks: list[str] = []
        for chunk in content:
            if isinstance(chunk, str):
                # Content lists may mix raw strings with typed blocks.
                text_chunks.append(chunk)
            elif isinstance(chunk, dict) and chunk.get("type") == "text":
                text_chunks.append(str(chunk.get("text", "")))
        return "\n".join(text_chunks).strip()
    return str(content).strip()
def create_text_response(prompt: str, model_name: str, temperature: float = 0) -> str:
    """Call the LangChain OpenAI chat model and return markdown output.

    Raises:
        ValueError: if the model response contains no extractable text.
    """
    llm = get_chat_model(model_name, temperature=temperature)
    raw_content = llm.invoke(prompt).content
    text_output = extract_text_content(raw_content)
    if not text_output:
        raise ValueError("OpenAI response did not include text output")
    return text_output
# Nodes
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
"""LangGraph node that generates search queries based on the User's question.
Uses Gemini 2.0 Flash to create an optimized search queries for web research based on
the User's question.
Args:
state: Current graph state containing the User's question
config: Configuration for the runnable, including LLM provider settings
Returns:
Dictionary with state update, including search_query key containing the generated queries
"""
"""Generate initial search queries from the user's request."""
configurable = Configuration.from_runnable_config(config)
reasoning_model = normalize_model_name(
state.get("reasoning_model"),
configurable.reasoning_model or configurable.query_generator_model,
)
# check for custom initial search query count
if state.get("initial_search_query_count") is None:
state["initial_search_query_count"] = configurable.number_of_initial_queries
# init Gemini 2.0 Flash
llm = ChatGoogleGenerativeAI(
model=configurable.query_generator_model,
temperature=1.0,
max_retries=2,
api_key=os.getenv("GEMINI_API_KEY"),
)
structured_llm = llm.with_structured_output(SearchQueryList)
# Format the prompt
current_date = get_current_date()
formatted_prompt = query_writer_instructions.format(
current_date=current_date,
current_date=get_current_date(),
research_topic=get_research_topic(state["messages"]),
number_queries=state["initial_search_query_count"],
)
# Generate the search queries
result = structured_llm.invoke(formatted_prompt)
result = create_structured_response(
formatted_prompt,
reasoning_model,
SearchQueryList,
temperature=1,
)
return {"search_query": result.query}
def continue_to_web_research(state: QueryGenerationState):
"""LangGraph node that sends the search queries to the web research node.
This is used to spawn n number of web research nodes, one for each search query.
"""
def continue_to_web_research(state: QueryGenerationState) -> list[Send]:
"""Fan out search queries into parallel web research nodes."""
return [
Send("web_research", {"search_query": search_query, "id": int(idx)})
for idx, search_query in enumerate(state["search_query"])
@@ -93,83 +136,63 @@ def continue_to_web_research(state: QueryGenerationState):
def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
"""LangGraph node that performs web research using the native Google Search API tool.
Executes a web search using the native Google Search API tool in combination with Gemini 2.0 Flash.
Args:
state: Current graph state containing the search query and research loop count
config: Configuration for the runnable, including search API settings
Returns:
Dictionary with state update, including sources_gathered, research_loop_count, and web_research_results
"""
# Configure
configurable = Configuration.from_runnable_config(config)
formatted_prompt = web_searcher_instructions.format(
current_date=get_current_date(),
research_topic=state["search_query"],
"""Execute Tavily search and return raw evidence for one query."""
search_query = shorten_search_query(state["search_query"])
tavily_response = get_tavily_client().search(
query=search_query,
topic=TAVILY_TOPIC,
search_depth=TAVILY_SEARCH_DEPTH,
chunks_per_source=TAVILY_CHUNKS_PER_SOURCE,
max_results=TAVILY_MAX_RESULTS,
include_answer=False,
include_raw_content=False,
)
sources = normalize_tavily_sources(tavily_response.get("results", []))
# Uses the google genai client as the langchain client doesn't return grounding metadata
response = genai_client.models.generate_content(
model=configurable.query_generator_model,
contents=formatted_prompt,
config={
"tools": [{"google_search": {}}],
"temperature": 0,
},
if not sources:
return {
"sources_gathered": [],
"search_query": [search_query],
"web_research_result": [
f'No Tavily results were returned for "{search_query}".'
],
}
evidence = "\n\n".join(
[
f"Search Query: {search_query}",
"Source Evidence:",
format_sources_for_prompt(sources),
]
)
# resolve the urls to short urls for saving tokens and time
resolved_urls = resolve_urls(
response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
)
# Gets the citations and adds them to the generated text
citations = get_citations(response, resolved_urls)
modified_text = insert_citation_markers(response.text, citations)
sources_gathered = [item for citation in citations for item in citation["segments"]]
return {
"sources_gathered": sources_gathered,
"search_query": [state["search_query"]],
"web_research_result": [modified_text],
"sources_gathered": sources,
"search_query": [search_query],
"web_research_result": [evidence],
}
def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
"""LangGraph node that identifies knowledge gaps and generates potential follow-up queries.
Analyzes the current summary to identify areas for further research and generates
potential follow-up queries. Uses structured output to extract
the follow-up query in JSON format.
Args:
state: Current graph state containing the running summary and research topic
config: Configuration for the runnable, including LLM provider settings
Returns:
Dictionary with state update, including search_query key containing the generated follow-up query
"""
"""Decide whether additional research is needed."""
configurable = Configuration.from_runnable_config(config)
# Increment the research loop count and get the reasoning model
state["research_loop_count"] = state.get("research_loop_count", 0) + 1
reasoning_model = state.get("reasoning_model", configurable.reflection_model)
reasoning_model = normalize_model_name(
state.get("reasoning_model"),
configurable.reasoning_model,
)
# Format the prompt
current_date = get_current_date()
formatted_prompt = reflection_instructions.format(
current_date=current_date,
current_date=get_current_date(),
research_topic=get_research_topic(state["messages"]),
summaries="\n\n---\n\n".join(state["web_research_result"]),
)
# init Reasoning Model
llm = ChatGoogleGenerativeAI(
model=reasoning_model,
temperature=1.0,
max_retries=2,
api_key=os.getenv("GEMINI_API_KEY"),
result = create_structured_response(
formatted_prompt,
reasoning_model,
Reflection,
temperature=1,
)
result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
return {
"is_sufficient": result.is_sufficient,
@@ -183,19 +206,8 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
def evaluate_research(
state: ReflectionState,
config: RunnableConfig,
) -> OverallState:
"""LangGraph routing function that determines the next step in the research flow.
Controls the research loop by deciding whether to continue gathering information
or to finalize the summary based on the configured maximum number of research loops.
Args:
state: Current graph state containing the research loop count
config: Configuration for the runnable, including max_research_loops setting
Returns:
String literal indicating the next node to visit ("web_research" or "finalize_summary")
"""
) -> OverallState | list[Send]:
"""Route either back into research or into final answer synthesis."""
configurable = Configuration.from_runnable_config(config)
max_research_loops = (
state.get("max_research_loops")
@@ -204,90 +216,53 @@ def evaluate_research(
)
if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
return "finalize_answer"
else:
return [
Send(
"web_research",
{
"search_query": follow_up_query,
"id": state["number_of_ran_queries"] + int(idx),
},
)
for idx, follow_up_query in enumerate(state["follow_up_queries"])
]
return [
Send(
"web_research",
{
"search_query": follow_up_query,
"id": state["number_of_ran_queries"] + int(idx),
},
)
for idx, follow_up_query in enumerate(state["follow_up_queries"])
]
def finalize_answer(state: OverallState, config: RunnableConfig):
"""LangGraph node that finalizes the research summary.
Prepares the final output by deduplicating and formatting sources, then
combining them with the running summary to create a well-structured
research report with proper citations.
Args:
state: Current graph state containing the running summary and sources gathered
Returns:
Dictionary with state update, including running_summary key containing the formatted final summary with sources
"""
def finalize_answer(state: OverallState, config: RunnableConfig) -> OverallState:
"""Generate the final cited answer from the accumulated research summaries."""
configurable = Configuration.from_runnable_config(config)
reasoning_model = state.get("reasoning_model") or configurable.answer_model
reasoning_model = normalize_model_name(
state.get("reasoning_model"),
configurable.reasoning_model or configurable.answer_model,
)
# Format the prompt
current_date = get_current_date()
formatted_prompt = answer_instructions.format(
current_date=current_date,
current_date=get_current_date(),
research_topic=get_research_topic(state["messages"]),
summaries="\n---\n\n".join(state["web_research_result"]),
)
# init Reasoning Model, default to Gemini 2.5 Flash
llm = ChatGoogleGenerativeAI(
model=reasoning_model,
temperature=0,
max_retries=2,
api_key=os.getenv("GEMINI_API_KEY"),
)
result = llm.invoke(formatted_prompt)
# Replace the short urls with the original urls and add all used urls to the sources_gathered
unique_sources = []
for source in state["sources_gathered"]:
if source["short_url"] in result.content:
result.content = result.content.replace(
source["short_url"], source["value"]
)
unique_sources.append(source)
answer = create_text_response(formatted_prompt, reasoning_model, temperature=0)
return {
"messages": [AIMessage(content=result.content)],
"sources_gathered": unique_sources,
"messages": [AIMessage(content=answer)],
"sources_gathered": state.get("sources_gathered", []),
}
# Assemble the research agent graph:
# generate_query -> (fan-out) web_research -> reflection -> loop or finalize.
builder = StateGraph(OverallState, config_schema=Configuration)
# Register the four nodes the run cycles between.
builder.add_node("generate_query", generate_query)
builder.add_node("web_research", web_research)
builder.add_node("reflection", reflection)
builder.add_node("finalize_answer", finalize_answer)
# `generate_query` is the entrypoint — the first node called after START.
builder.add_edge(START, "generate_query")
# Fan out: continue_to_web_research returns one Send per generated query,
# spawning parallel web_research branches.
builder.add_conditional_edges(
    "generate_query", continue_to_web_research, ["web_research"]
)
# Every research branch feeds into reflection.
builder.add_edge("web_research", "reflection")
# evaluate_research either loops back into web_research or finalizes.
builder.add_conditional_edges(
    "reflection", evaluate_research, ["web_research", "finalize_answer"]
)
# finalize_answer produces the cited answer and ends the run.
builder.add_edge("finalize_answer", END)
# Compiled graph exported for the LangGraph runtime.
graph = builder.compile(name="pro-search-agent")

View File

@@ -15,6 +15,7 @@ Instructions:
- Queries should be diverse, if the topic is broad, generate more than 1 query.
- Don't generate multiple similar queries, 1 is enough.
- Query should ensure that the most current information is gathered. The current date is {current_date}.
- Keep every individual query under 400 characters. Prefer concise search terms over full-sentence restatements.
Format:
- Format your response as a JSON object with ALL two of these exact keys:
@@ -34,26 +35,31 @@ Topic: What revenue grew more last year apple stock or the number of people buyi
Context: {research_topic}"""
web_searcher_instructions = """Conduct targeted Google Searches to gather the most recent, credible information on "{research_topic}" and synthesize it into a verifiable text artifact.
web_searcher_instructions = """Review the provided Tavily search results for "{research_topic}" and synthesize them into a verifiable research note.
Instructions:
- Query should ensure that the most current information is gathered. The current date is {current_date}.
- Conduct multiple, diverse searches to gather comprehensive information.
- Consolidate key findings while meticulously tracking the source(s) for each specific piece of information.
- The output should be a well-written summary or report based on your search findings.
- Only include the information found in the search results, don't make up any information.
- Use only the provided Tavily search results. Do not invent or infer facts that are not supported by those results.
- Consolidate the key findings into a concise research note for this query.
- Every factual paragraph or bullet must include at least one markdown citation using the exact title and URL from the provided sources, for example: [Reuters](https://www.reuters.com/example).
- Preserve full URLs in citations. Do not use placeholders, IDs, or shortened URLs.
- If the results are insufficient, explicitly say what is missing and cite the closest available sources.
Research Topic:
{research_topic}
Search Results:
{search_results}
"""
reflection_instructions = """You are an expert research assistant analyzing summaries about "{research_topic}".
reflection_instructions = """You are an expert research assistant analyzing collected research evidence about "{research_topic}".
Instructions:
- Identify knowledge gaps or areas that need deeper exploration and generate a follow-up query. (1 or multiple).
- If provided summaries are sufficient to answer the user's question, don't generate a follow-up query.
- If the provided evidence is sufficient to answer the user's question, don't generate a follow-up query.
- If there is a knowledge gap, generate a follow-up query that would help expand your understanding.
- Focus on technical details, implementation specifics, or emerging trends that weren't fully covered.
- Keep every follow-up query under 400 characters. Prefer compact search phrases over long natural-language questions.
Requirements:
- Ensure the follow-up query is self-contained and includes necessary context for web search.
@@ -73,24 +79,26 @@ Example:
}}
```
Reflect carefully on the Summaries to identify knowledge gaps and produce a follow-up query. Then, produce your output following this JSON format:
Reflect carefully on the Research Evidence to identify knowledge gaps and produce a follow-up query. Then, produce your output following this JSON format:
Summaries:
Research Evidence:
{summaries}
"""
answer_instructions = """Generate a high-quality answer to the user's question based on the provided summaries.
answer_instructions = """Generate a high-quality answer to the user's question based on the provided research evidence.
Instructions:
- The current date is {current_date}.
- You are the final step of a multi-step research process, don't mention that you are the final step.
- You have access to all the information gathered from the previous steps.
- You have access to the user's question.
- Generate a high-quality answer to the user's question based on the provided summaries and the user's question.
- Include the sources you used from the Summaries in the answer correctly, use markdown format (e.g. [apnews](https://vertexaisearch.cloud.google.com/id/1-0)). THIS IS A MUST.
- Generate a high-quality answer to the user's question based on the provided research evidence and the user's question.
- Use only the evidence present in the Research Evidence.
- Preserve or reuse source citations from the Research Evidence as full markdown links. Every factual paragraph should include at least one citation.
- Do not invent sources, placeholders, or shortened URLs.
User Context:
- {research_topic}
Summaries:
Research Evidence:
{summaries}"""

View File

@@ -29,13 +29,8 @@ class ReflectionState(TypedDict):
number_of_ran_queries: int
class Query(TypedDict):
query: str
rationale: str
class QueryGenerationState(TypedDict):
search_query: list[Query]
search_query: list[str]
class WebSearchState(TypedDict):

View File

@@ -1,166 +1,106 @@
from typing import Any, Dict, List
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
from __future__ import annotations
import re
from typing import Any, List
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
DEFAULT_REASONING_MODEL = "openai/gpt-5.4"
SUPPORTED_MODEL_PREFIXES = ("gpt-", "deepseek-", "glm-")
MAX_TAVILY_QUERY_LENGTH = 380
def get_research_topic(messages: List[AnyMessage]) -> str:
"""
Get the research topic from the messages.
"""
# check if request has a history and combine the messages into a single string
"""Get the research topic from the messages."""
if len(messages) == 1:
research_topic = messages[-1].content
else:
research_topic = ""
for message in messages:
if isinstance(message, HumanMessage):
research_topic += f"User: {message.content}\n"
elif isinstance(message, AIMessage):
research_topic += f"Assistant: {message.content}\n"
return str(messages[-1].content)
research_topic = ""
for message in messages:
if isinstance(message, HumanMessage):
research_topic += f"User: {message.content}\n"
elif isinstance(message, AIMessage):
research_topic += f"Assistant: {message.content}\n"
return research_topic
def resolve_urls(urls_to_resolve: List[Any], id: int) -> Dict[str, str]:
"""
Create a map of the vertex ai search urls (very long) to a short url with a unique id for each url.
Ensures each original URL gets a consistent shortened form while maintaining uniqueness.
"""
prefix = f"https://vertexaisearch.cloud.google.com/id/"
urls = [site.web.uri for site in urls_to_resolve]
def normalize_model_name(
model_name: str | None,
fallback: str = DEFAULT_REASONING_MODEL,
) -> str:
"""Normalize stale or unsupported model names to a supported default."""
candidate = (model_name or "").strip()
if candidate.startswith(SUPPORTED_MODEL_PREFIXES):
return candidate
# Create a dictionary that maps each unique URL to its first occurrence index
resolved_map = {}
for idx, url in enumerate(urls):
if url not in resolved_map:
resolved_map[url] = f"{prefix}{id}-{idx}"
provider, _, model = candidate.partition("/")
if provider and model.startswith(SUPPORTED_MODEL_PREFIXES):
return candidate
return resolved_map
return fallback
def insert_citation_markers(text, citations_list):
"""
Inserts citation markers into a text string based on start and end indices.
def _normalize_source_title(source: dict[str, Any], index: int) -> str:
title = (source.get("title") or "").strip()
return title or f"Source {index}"
Args:
text (str): The original text string.
citations_list (list): A list of dictionaries, where each dictionary
contains 'start_index', 'end_index', and
'segment_string' (the marker to insert).
Indices are assumed to be for the original text.
Returns:
str: The text with citation markers inserted.
"""
# Sort citations by end_index in descending order.
# If end_index is the same, secondary sort by start_index descending.
# This ensures that insertions at the end of the string don't affect
# the indices of earlier parts of the string that still need to be processed.
sorted_citations = sorted(
citations_list, key=lambda c: (c["end_index"], c["start_index"]), reverse=True
)
def normalize_tavily_sources(results: list[dict[str, Any]]) -> list[dict[str, str]]:
"""Normalize Tavily results into a stable source shape."""
normalized_sources: list[dict[str, str]] = []
seen_urls: set[str] = set()
modified_text = text
for citation_info in sorted_citations:
# These indices refer to positions in the *original* text,
# but since we iterate from the end, they remain valid for insertion
# relative to the parts of the string already processed.
end_idx = citation_info["end_index"]
marker_to_insert = ""
for segment in citation_info["segments"]:
marker_to_insert += f" [{segment['label']}]({segment['short_url']})"
# Insert the citation marker at the original end_idx position
modified_text = (
modified_text[:end_idx] + marker_to_insert + modified_text[end_idx:]
for index, result in enumerate(results, start=1):
url = (result.get("url") or "").strip()
if not url or url in seen_urls:
continue
normalized_sources.append(
{
"label": f"S{len(normalized_sources) + 1}",
"title": _normalize_source_title(result, index),
"value": url,
"content": (result.get("content") or "").strip(),
}
)
seen_urls.add(url)
return normalized_sources
def format_sources_for_prompt(sources: list[dict[str, str]]) -> str:
"""Format normalized search results for prompt injection."""
if not sources:
return "No search results were returned."
formatted_sources = []
for source in sources:
snippet = source.get("content") or "No snippet available."
formatted_sources.append(
"\n".join(
[
f"{source['label']}: {source['title']}",
f"URL: {source['value']}",
f"Snippet: {snippet}",
]
)
)
return modified_text
return "\n\n".join(formatted_sources)
def get_citations(response, resolved_urls_map):
"""
Extracts and formats citation information from a Gemini model's response.
def shorten_search_query(
query: str,
max_length: int = MAX_TAVILY_QUERY_LENGTH,
) -> str:
"""Normalize and trim search queries to fit Tavily's length limit."""
normalized_query = re.sub(r"\s+", " ", query).strip()
if len(normalized_query) <= max_length:
return normalized_query
This function processes the grounding metadata provided in the response to
construct a list of citation objects. Each citation object includes the
start and end indices of the text segment it refers to, and a string
containing formatted markdown links to the supporting web chunks.
truncated_query = normalized_query[:max_length]
last_space = truncated_query.rfind(" ")
if last_space > max_length // 2:
truncated_query = truncated_query[:last_space]
Args:
response: The response object from the Gemini model, expected to have
a structure including `candidates[0].grounding_metadata`.
It also relies on a `resolved_map` being available in its
scope to map chunk URIs to resolved URLs.
Returns:
list: A list of dictionaries, where each dictionary represents a citation
and has the following keys:
- "start_index" (int): The starting character index of the cited
segment in the original text. Defaults to 0
if not specified.
- "end_index" (int): The character index immediately after the
end of the cited segment (exclusive).
- "segments" (list[str]): A list of individual markdown-formatted
links for each grounding chunk.
- "segment_string" (str): A concatenated string of all markdown-
formatted links for the citation.
Returns an empty list if no valid candidates or grounding supports
are found, or if essential data is missing.
"""
citations = []
# Ensure response and necessary nested structures are present
if not response or not response.candidates:
return citations
candidate = response.candidates[0]
if (
not hasattr(candidate, "grounding_metadata")
or not candidate.grounding_metadata
or not hasattr(candidate.grounding_metadata, "grounding_supports")
):
return citations
for support in candidate.grounding_metadata.grounding_supports:
citation = {}
# Ensure segment information is present
if not hasattr(support, "segment") or support.segment is None:
continue # Skip this support if segment info is missing
start_index = (
support.segment.start_index
if support.segment.start_index is not None
else 0
)
# Ensure end_index is present to form a valid segment
if support.segment.end_index is None:
continue # Skip if end_index is missing, as it's crucial
# Add 1 to end_index to make it an exclusive end for slicing/range purposes
# (assuming the API provides an inclusive end_index)
citation["start_index"] = start_index
citation["end_index"] = support.segment.end_index
citation["segments"] = []
if (
hasattr(support, "grounding_chunk_indices")
and support.grounding_chunk_indices
):
for ind in support.grounding_chunk_indices:
try:
chunk = candidate.grounding_metadata.grounding_chunks[ind]
resolved_url = resolved_urls_map.get(chunk.web.uri, None)
citation["segments"].append(
{
"label": chunk.web.title.split(".")[:-1][0],
"short_url": resolved_url,
"value": chunk.web.uri,
}
)
except (IndexError, AttributeError, NameError):
# Handle cases where chunk, web, uri, or resolved_map might be problematic
# For simplicity, we'll just skip adding this particular segment link
# In a production system, you might want to log this.
pass
citations.append(citation)
return citations
return truncated_query.rstrip(" ,;:-")