import logging
import os

from dotenv import load_dotenv
from google.genai import Client
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send

from agent.configuration import Configuration
from agent.prompts import (
    answer_instructions,
    get_current_date,
    query_writer_instructions,
    reflection_instructions,
    web_searcher_instructions,
)
from agent.state import (
    OverallState,
    QueryGenerationState,
    ReflectionState,
    WebSearchState,
)
from agent.tools_and_schemas import SearchQueryList, Reflection
from agent.utils import (
    get_citations,
    get_research_topic,
    insert_citation_markers,
    resolve_urls,
)

load_dotenv()

logging.basicConfig(
    level=logging.INFO,  # log level
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # log format
)
logger = logging.getLogger(__name__)

if os.getenv("GEMINI_API_KEY") is None:
    raise ValueError("GEMINI_API_KEY is not set")

# Used for the Google Search API
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))


# Nodes
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
    """LangGraph node that generates search queries based on the user's question.

    Uses the configured query generator model (Gemini 2.0 Flash by default) to
    create optimized search queries for web research based on the user's question.

    Args:
        state: Current graph state containing the user's question
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including a search_query key containing
        the generated queries
    """
    logger.info("Start: generate_query")
    configurable = Configuration.from_runnable_config(config)

    # Check for a custom initial search query count
    if state.get("initial_search_query_count") is None:
        state["initial_search_query_count"] = configurable.number_of_initial_queries

    # Init the query generator model
    llm = ChatGoogleGenerativeAI(
        model=configurable.query_generator_model,
        temperature=1.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    structured_llm = llm.with_structured_output(SearchQueryList)

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = query_writer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        number_queries=state["initial_search_query_count"],
    )
    # Generate the search queries
    result = structured_llm.invoke(formatted_prompt)
    logger.info("End: generate_query")
    return {"search_query": result.query}
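

# Note (assumption, not shown in this module): SearchQueryList is expected to
# be a Pydantic model with a `query: list[str]` field, so
# with_structured_output parses the model's JSON reply into that schema and
# `result.query` above is the list of generated search strings.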


def continue_to_web_research(state: QueryGenerationState):
    """LangGraph routing function that sends the search queries to the web research node.

    This is used to spawn n web research nodes, one for each search query.
    """
    logger.info("Start: continue_to_web_research")
    return [
        Send("web_research", {"search_query": search_query, "id": int(idx)})
        for idx, search_query in enumerate(state["search_query"])
    ]
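

# Illustrative example (hypothetical queries): with state["search_query"] ==
# ["langgraph send api", "langgraph fan-out"], the function returns
#   [Send("web_research", {"search_query": "langgraph send api", "id": 0}),
#    Send("web_research", {"search_query": "langgraph fan-out", "id": 1})]
# so LangGraph runs two web_research branches in parallel, each with its own
# WebSearchState.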


def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
    """LangGraph node that performs web research using the native Google Search API tool.

    Executes a web search using the native Google Search API tool in
    combination with the configured query generator model.

    Args:
        state: Current graph state containing the search query and research loop count
        config: Configuration for the runnable, including search API settings

    Returns:
        Dictionary with state update, including sources_gathered, search_query,
        and web_research_result
    """
    logger.info("Start: web_research")
    # Configure
    configurable = Configuration.from_runnable_config(config)
    formatted_prompt = web_searcher_instructions.format(
        current_date=get_current_date(),
        research_topic=state["search_query"],
    )

    # Uses the google genai client, as the langchain client doesn't return grounding metadata
    response = genai_client.models.generate_content(
        model=configurable.query_generator_model,
        contents=formatted_prompt,
        config={
            "tools": [{"google_search": {}}],
            "temperature": 0,
        },
    )
    # Resolve the urls to short urls to save tokens and time
    resolved_urls = resolve_urls(
        response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
    )
    # Get the citations and add them to the generated text
    citations = get_citations(response, resolved_urls)
    modified_text = insert_citation_markers(response.text, citations)
    sources_gathered = [item for citation in citations for item in citation["segments"]]
    logger.info("End: web_research")
    return {
        "sources_gathered": sources_gathered,
        "search_query": [state["search_query"]],
        "web_research_result": [modified_text],
    }
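

# Note (assumption, based on how the updates above are merged): OverallState
# is expected to declare search_query, web_research_result, and
# sources_gathered as accumulating list fields (e.g. Annotated[list,
# operator.add]), which is why web_research returns one-element lists that
# parallel branches can append to without clobbering each other.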


def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
    """LangGraph node that identifies knowledge gaps and generates potential follow-up queries.

    Analyzes the current summaries to identify areas for further research and
    generates potential follow-up queries. Uses structured output to extract
    the follow-up queries in JSON format.

    Args:
        state: Current graph state containing the running summary and research topic
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including a follow_up_queries key
        containing the generated follow-up queries
    """
    logger.info("Start: reflection")
    configurable = Configuration.from_runnable_config(config)
    # Increment the research loop count and get the reasoning model
    state["research_loop_count"] = state.get("research_loop_count", 0) + 1
    reasoning_model = state.get("reasoning_model", configurable.reflection_model)

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = reflection_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n\n---\n\n".join(state["web_research_result"]),
    )
    # Init the reasoning model
    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=1.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
    logger.info("End: reflection")
    return {
        "is_sufficient": result.is_sufficient,
        "knowledge_gap": result.knowledge_gap,
        "follow_up_queries": result.follow_up_queries,
        "research_loop_count": state["research_loop_count"],
        "number_of_ran_queries": len(state["search_query"]),
    }
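

# Note (assumption, not shown in this module): Reflection is expected to be a
# Pydantic model with is_sufficient (bool), knowledge_gap (str), and
# follow_up_queries (list[str]) fields, matching the attributes read above.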


def evaluate_research(
    state: ReflectionState,
    config: RunnableConfig,
):
    """LangGraph routing function that determines the next step in the research flow.

    Controls the research loop by deciding whether to continue gathering
    information or to finalize the answer based on the configured maximum
    number of research loops.

    Args:
        state: Current graph state containing the research loop count
        config: Configuration for the runnable, including max_research_loops setting

    Returns:
        Either the string "finalize_answer" or a list of Send objects that
        route the follow-up queries back to the web_research node
    """
    logger.info("Start: evaluate_research")
    configurable = Configuration.from_runnable_config(config)
    max_research_loops = (
        state.get("max_research_loops")
        if state.get("max_research_loops") is not None
        else configurable.max_research_loops
    )
    logger.info("End: evaluate_research")
    if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
        return "finalize_answer"
    else:
        return [
            Send(
                "web_research",
                {
                    "search_query": follow_up_query,
                    "id": state["number_of_ran_queries"] + int(idx),
                },
            )
            for idx, follow_up_query in enumerate(state["follow_up_queries"])
        ]
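

# Illustrative example (hypothetical values): with is_sufficient=False,
# research_loop_count=1, max_research_loops=2, three queries already run, and
# two follow-up queries, this returns Send objects with ids 3 and 4, so
# follow-up results never collide with the ids of earlier queries.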


def finalize_answer(state: OverallState, config: RunnableConfig):
    """LangGraph node that finalizes the research answer.

    Prepares the final output by deduplicating and formatting sources, then
    combining them with the running summary to create a well-structured
    research report with proper citations.

    Args:
        state: Current graph state containing the running summary and sources gathered
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including a messages key containing the
        formatted final answer with sources
    """
    logger.info("Start: finalize_answer")
    configurable = Configuration.from_runnable_config(config)
    reasoning_model = state.get("reasoning_model") or configurable.answer_model

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = answer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n---\n\n".join(state["web_research_result"]),
    )

    # Init the reasoning model, defaulting to the configured answer model
    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    logger.info("Start: llm.invoke")
    result = llm.invoke(formatted_prompt)
    logger.info("End: llm.invoke: %s", result)

    # Replace the short urls with the original urls and keep only the sources
    # that were actually cited in the answer
    unique_sources = []
    for source in state["sources_gathered"]:
        if source["short_url"] in result.content:
            result.content = result.content.replace(
                source["short_url"], source["value"]
            )
            unique_sources.append(source)
    logger.info("End: finalize_answer")
    return {
        "messages": [AIMessage(content=result.content)],
        "sources_gathered": unique_sources,
    }
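

# Illustrative example (hypothetical values): if resolve_urls mapped
# https://example.com/article to a short placeholder such as
# https://vertexaisearch.cloud.google.com/id/0-3, the loop above swaps the
# placeholder back to the original url in the final answer and keeps only the
# sources that were cited.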


# Create our agent graph
builder = StateGraph(OverallState, config_schema=Configuration)

# Define the nodes we will cycle between
builder.add_node("generate_query", generate_query)
builder.add_node("web_research", web_research)
builder.add_node("reflection", reflection)
builder.add_node("finalize_answer", finalize_answer)

# Set the entrypoint as `generate_query`, the first node to be called
builder.add_edge(START, "generate_query")
# Add a conditional edge to fan the search queries out into parallel branches
builder.add_conditional_edges(
    "generate_query", continue_to_web_research, ["web_research"]
)
# Reflect on the web research
builder.add_edge("web_research", "reflection")
# Evaluate the research
builder.add_conditional_edges(
    "reflection", evaluate_research, ["web_research", "finalize_answer"]
)
# Finalize the answer
builder.add_edge("finalize_answer", END)

graph = builder.compile(name="pro-search-agent")
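

# A minimal local-run sketch (not part of the original module); it assumes a
# valid GEMINI_API_KEY in the environment and that OverallState accepts
# "messages" in the LangChain role/content dict format.
if __name__ == "__main__":
    final_state = graph.invoke(
        {"messages": [{"role": "user", "content": "Who won Euro 2024?"}]}
    )
    # finalize_answer appends the final report as the last AIMessage.
    print(final_state["messages"][-1].content)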
|