graph.py

import os
import logging

from agent.tools_and_schemas import SearchQueryList, Reflection
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.types import Send
from langgraph.graph import StateGraph
from langgraph.graph import START, END
from langchain_core.runnables import RunnableConfig
from google.genai import Client
from agent.state import (
    OverallState,
    QueryGenerationState,
    ReflectionState,
    WebSearchState,
)
from agent.configuration import Configuration
from agent.prompts import (
    get_current_date,
    query_writer_instructions,
    web_searcher_instructions,
    reflection_instructions,
    answer_instructions,
)
from langchain_google_genai import ChatGoogleGenerativeAI
from agent.utils import (
    get_citations,
    get_research_topic,
    insert_citation_markers,
    resolve_urls,
)

load_dotenv()

logging.basicConfig(
    level=logging.INFO,  # log level
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',  # log format
)
logger = logging.getLogger(__name__)

if os.getenv("GEMINI_API_KEY") is None:
    raise ValueError("GEMINI_API_KEY is not set")

# Used for the Google Search API
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))
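
# Note: two Gemini clients coexist in this module. The raw google.genai Client
# above serves web_research(), which needs grounding metadata that the LangChain
# wrapper does not return; ChatGoogleGenerativeAI (imported above) is used
# wherever a plain or structured chat completion is enough.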

# Nodes


def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
    """LangGraph node that generates search queries based on the User's question.

    Uses Gemini 2.0 Flash to create optimized search queries for web research based on
    the User's question.

    Args:
        state: Current graph state containing the User's question
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including search_query key containing the generated queries
    """
    configurable = Configuration.from_runnable_config(config)
    logger.info("Start: generate_query")
    logger.info("1: generate_query")
    # Check for a custom initial search query count
    if state.get("initial_search_query_count") is None:
        state["initial_search_query_count"] = configurable.number_of_initial_queries
    logger.info("2: generate_query")
    # Init Gemini 2.0 Flash
    llm = ChatGoogleGenerativeAI(
        model=configurable.query_generator_model,
        temperature=1.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    logger.info("3: generate_query")
    structured_llm = llm.with_structured_output(SearchQueryList)
    logger.info("4: generate_query")
    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = query_writer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        number_queries=state["initial_search_query_count"],
    )
    # Generate the search queries
    # print("formatted_prompt: ", formatted_prompt)
    result = structured_llm.invoke(formatted_prompt)
    logger.info("End: generate_query")
    return {"search_query": result.query}


def continue_to_web_research(state: QueryGenerationState):
    """LangGraph routing function that sends the search queries to the web research node.

    This is used to spawn n web research nodes, one for each search query.
    """
    logger.info("Start: continue_to_web_research")
    return [
        Send("web_research", {"search_query": search_query, "id": int(idx)})
        for idx, search_query in enumerate(state["search_query"])
    ]
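
# Note on the fan-out above: Send is LangGraph's mechanism for spawning parallel
# branches. Each Send dispatches one invocation of the "web_research" node with
# its own payload ({"search_query": ..., "id": ...}), and the branches' state
# updates are merged back into OverallState when they complete.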


def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
    """LangGraph node that performs web research using the native Google Search API tool.

    Executes a web search using the native Google Search API tool in combination with Gemini 2.0 Flash.

    Args:
        state: Current graph state containing the search query and research loop count
        config: Configuration for the runnable, including search API settings

    Returns:
        Dictionary with state update, including sources_gathered, search_query, and web_research_result
    """
    logger.info("Start: web_research")
    # Configure
    configurable = Configuration.from_runnable_config(config)
    formatted_prompt = web_searcher_instructions.format(
        current_date=get_current_date(),
        research_topic=state["search_query"],
    )
    # Uses the google genai client as the langchain client doesn't return grounding metadata
    response = genai_client.models.generate_content(
        model=configurable.query_generator_model,
        contents=formatted_prompt,
        config={
            "tools": [{"google_search": {}}],
            "temperature": 0,
        },
    )
    # chunks = response.candidates[0].grounding_metadata.grounding_chunks
    # for chunk in chunks:
    #     print(chunk["title"])
    #     print(chunk["url"])
    #     print(chunk["content"])
    # Resolve the urls to short urls for saving tokens and time
    resolved_urls = resolve_urls(
        response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
    )
    # Get the citations and add them to the generated text
    citations = get_citations(response, resolved_urls)
    modified_text = insert_citation_markers(response.text, citations)
    sources_gathered = [item for citation in citations for item in citation["segments"]]
    logger.info("End: web_research")
    return {
        "sources_gathered": sources_gathered,
        "search_query": [state["search_query"]],
        "web_research_result": [modified_text],
    }
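
# Note: search_query and web_research_result are returned as single-element
# lists on purpose. With parallel web_research branches in flight, OverallState
# (see agent.state) is assumed to declare list-accumulating reducers for these
# keys, so each branch appends to the shared state instead of overwriting it.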


def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
    """LangGraph node that identifies knowledge gaps and generates potential follow-up queries.

    Analyzes the current summary to identify areas for further research and generates
    potential follow-up queries. Uses structured output to extract
    the follow-up query in JSON format.

    Args:
        state: Current graph state containing the running summary and research topic
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including is_sufficient, knowledge_gap,
        follow_up_queries, research_loop_count, and number_of_ran_queries
    """
    logger.info("Start: reflection")
    configurable = Configuration.from_runnable_config(config)
    # Increment the research loop count and get the reasoning model
    state["research_loop_count"] = state.get("research_loop_count", 0) + 1
    reasoning_model = state.get("reasoning_model", configurable.reflection_model)
    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = reflection_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n\n---\n\n".join(state["web_research_result"]),
    )
    # Init Reasoning Model
    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=1.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
    logger.info("End: reflection")
    return {
        "is_sufficient": result.is_sufficient,
        "knowledge_gap": result.knowledge_gap,
        "follow_up_queries": result.follow_up_queries,
        "research_loop_count": state["research_loop_count"],
        "number_of_ran_queries": len(state["search_query"]),
    }


def evaluate_research(
    state: ReflectionState,
    config: RunnableConfig,
):
    """LangGraph routing function that determines the next step in the research flow.

    Controls the research loop by deciding whether to continue gathering information
    or to finalize the answer based on the configured maximum number of research loops.

    Args:
        state: Current graph state containing the research loop count
        config: Configuration for the runnable, including max_research_loops setting

    Returns:
        Either the string literal "finalize_answer" or a list of Send objects that
        route the follow-up queries back to "web_research"
    """
    logger.info("Start: evaluate_research")
    configurable = Configuration.from_runnable_config(config)
    max_research_loops = (
        state.get("max_research_loops")
        if state.get("max_research_loops") is not None
        else configurable.max_research_loops
    )
    logger.info("End: evaluate_research")
    if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
        return "finalize_answer"
    else:
        return [
            Send(
                "web_research",
                {
                    "search_query": follow_up_query,
                    "id": state["number_of_ran_queries"] + int(idx),
                },
            )
            for idx, follow_up_query in enumerate(state["follow_up_queries"])
        ]
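
# Note on the id offsetting above: number_of_ran_queries is len(state["search_query"])
# as set by reflection(), so follow-up branches receive ids that continue after the
# initial batch (e.g. 3 initial queries -> follow-ups get ids 3, 4, ...), presumably
# keeping the short urls produced by resolve_urls distinct across branches.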


def finalize_answer(state: OverallState, config: RunnableConfig):
    """LangGraph node that finalizes the research summary.

    Prepares the final output by deduplicating and formatting sources, then
    combining them with the running summary to create a well-structured
    research report with proper citations.

    Args:
        state: Current graph state containing the running summary and sources gathered

    Returns:
        Dictionary with state update: messages containing the final answer with
        citations, and sources_gathered reduced to the sources actually cited
    """
    logger.info("Start: finalize_answer")
    configurable = Configuration.from_runnable_config(config)
    reasoning_model = state.get("reasoning_model") or configurable.answer_model
    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = answer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n---\n\n".join(state["web_research_result"]),
    )
    # Init Reasoning Model, default to Gemini 2.5 Flash
    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    logger.info("Start: llm.invoke")
    result = llm.invoke(formatted_prompt)
    # Stdlib logging uses %-style placeholders, not "{}" braces
    logger.info("End: llm.invoke: %s", result)
    # Replace the short urls with the original urls and add all used urls to the sources_gathered
    unique_sources = []
    for source in state["sources_gathered"]:
        if source["short_url"] in result.content:
            result.content = result.content.replace(
                source["short_url"], source["value"]
            )
            unique_sources.append(source)
    # Save the result to a markdown file
    # with open(f"result_{get_research_topic(state['messages'])}.md", "w", encoding="utf-8") as f:
    #     f.write(result.content)
    #     print(f"Result saved to {f.name}")
    logger.info("End: finalize_answer")
    return {
        "messages": [AIMessage(content=result.content)],
        "sources_gathered": unique_sources,
    }


# Create our Agent Graph
builder = StateGraph(OverallState, config_schema=Configuration)

# Define the nodes we will cycle between
builder.add_node("generate_query", generate_query)
builder.add_node("web_research", web_research)
builder.add_node("reflection", reflection)
builder.add_node("finalize_answer", finalize_answer)

# Set the entrypoint as `generate_query`
# This means that this node is the first one called
builder.add_edge(START, "generate_query")
# Add conditional edge to continue with search queries in a parallel branch
builder.add_conditional_edges(
    "generate_query", continue_to_web_research, ["web_research"]
)
# Reflect on the web research
builder.add_edge("web_research", "reflection")
# Evaluate the research
builder.add_conditional_edges(
    "reflection", evaluate_research, ["web_research", "finalize_answer"]
)
# Finalize the answer
builder.add_edge("finalize_answer", END)

graph = builder.compile(name="pro-search-agent")
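

# Example usage: a minimal sketch, not part of the original module. The input
# keys below are the ones this file reads from OverallState; the question text
# is a hypothetical placeholder.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    final_state = graph.invoke(
        {
            "messages": [HumanMessage(content="What is retrieval-augmented generation?")],
            "initial_search_query_count": 3,
            "max_research_loops": 2,
        }
    )
    # finalize_answer appends the cited answer as the last AIMessage
    print(final_state["messages"][-1].content)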