|
@@ -7,7 +7,8 @@ from langgraph.types import Send
|
|
|
from langgraph.graph import StateGraph
|
|
|
from langgraph.graph import START, END
|
|
|
from langchain_core.runnables import RunnableConfig
|
|
|
-from google.genai import Client
|
|
|
+from google import genai
|
|
|
+from google.genai.types import HttpOptions
|
|
|
|
|
|
from agent.state import (
|
|
|
OverallState,
|
|
@@ -23,13 +24,15 @@ from agent.prompts import (
|
|
|
reflection_instructions,
|
|
|
answer_instructions,
|
|
|
)
|
|
|
-from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
from agent.utils import (
|
|
|
get_citations,
|
|
|
get_research_topic,
|
|
|
insert_citation_markers,
|
|
|
resolve_urls,
|
|
|
)
|
|
|
+
|
|
|
+from langchain_google_vertexai import ChatVertexAI
|
|
|
+
|
|
|
import logging
|
|
|
|
|
|
load_dotenv()
|
|
@@ -41,12 +44,21 @@ logging.basicConfig(
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
-if os.getenv("GEMINI_API_KEY") is None:
|
|
|
- raise ValueError("GEMINI_API_KEY is not set")
|
|
|
+# if os.getenv("GEMINI_API_KEY") is None:
|
|
|
+# raise ValueError("GEMINI_API_KEY is not set")
|
|
|
|
|
|
-# Used for Google Search API
|
|
|
-genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
|
|
+if os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is None:
|
|
|
+ raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set")
|
|
|
+if os.getenv("GOOGLE_CLOUD_PROJECT") is None:
|
|
|
+ raise ValueError("GOOGLE_CLOUD_PROJECT is not set")
|
|
|
+if os.getenv("GOOGLE_CLOUD_LOCATION") is None:
|
|
|
+ raise ValueError("GOOGLE_CLOUD_LOCATION is not set")
|
|
|
+if os.getenv("GOOGLE_GENAI_USE_VERTEXAI") is None:
|
|
|
+ raise ValueError("GOOGLE_GENAI_USE_VERTEXAI is not set")
|
|
|
|
|
|
+# Used for Google Search API
|
|
|
+# genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
|
|
+genai_client = genai.Client(http_options=HttpOptions(api_version="v1"))
|
|
|
|
|
|
# Nodes
|
|
|
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
|
|
@@ -64,21 +76,25 @@ def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerati
|
|
|
"""
|
|
|
configurable = Configuration.from_runnable_config(config)
|
|
|
logger.info("开始:generate_query")
|
|
|
- logger.info("1:generate_query")
|
|
|
# check for custom initial search query count
|
|
|
if state.get("initial_search_query_count") is None:
|
|
|
state["initial_search_query_count"] = configurable.number_of_initial_queries
|
|
|
- logger.info("2:generate_query")
|
|
|
# init Gemini 2.0 Flash
|
|
|
- llm = ChatGoogleGenerativeAI(
|
|
|
- model=configurable.query_generator_model,
|
|
|
+ # llm = ChatGoogleGenerativeAI(
|
|
|
+ # model=configurable.query_generator_model,
|
|
|
+ # temperature=1.0,
|
|
|
+ # max_retries=2,
|
|
|
+ # api_key=os.getenv("GEMINI_API_KEY"),
|
|
|
+ # )
|
|
|
+
|
|
|
+ llm = ChatVertexAI(
|
|
|
+ model_name=configurable.query_generator_model, # 如 "gemini-pro"
|
|
|
temperature=1.0,
|
|
|
max_retries=2,
|
|
|
- api_key=os.getenv("GEMINI_API_KEY"),
|
|
|
+ project=os.getenv("GOOGLE_CLOUD_PROJECT"), # 必填
|
|
|
+ location=os.getenv("GOOGLE_CLOUD_LOCATION") # 必填
|
|
|
)
|
|
|
- logger.info("3:generate_query")
|
|
|
structured_llm = llm.with_structured_output(SearchQueryList)
|
|
|
- logger.info("4:generate_query")
|
|
|
# Format the prompt
|
|
|
current_date = get_current_date()
|
|
|
formatted_prompt = query_writer_instructions.format(
|
|
@@ -139,7 +155,7 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
|
|
|
# for chunk in chunks:
|
|
|
# print(chunk["title"])
|
|
|
# print(chunk["url"])
|
|
|
- # print(chunk["content"])
|
|
|
+ # print(chunk["content"])
|
|
|
# resolve the urls to short urls for saving tokens and time
|
|
|
resolved_urls = resolve_urls(
|
|
|
response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
|
|
@@ -184,11 +200,18 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
|
|
|
summaries="\n\n---\n\n".join(state["web_research_result"]),
|
|
|
)
|
|
|
# init Reasoning Model
|
|
|
- llm = ChatGoogleGenerativeAI(
|
|
|
- model=reasoning_model,
|
|
|
+ # llm = ChatGoogleGenerativeAI(
|
|
|
+ # model=reasoning_model,
|
|
|
+ # temperature=1.0,
|
|
|
+ # max_retries=2,
|
|
|
+ # api_key=os.getenv("GEMINI_API_KEY"),
|
|
|
+ # )
|
|
|
+ llm = ChatVertexAI(
|
|
|
+        model_name=reasoning_model,  # e.g. "gemini-pro"; was wrongly copied from generate_query as query_generator_model
|
|
|
temperature=1.0,
|
|
|
max_retries=2,
|
|
|
- api_key=os.getenv("GEMINI_API_KEY"),
|
|
|
+ project=os.getenv("GOOGLE_CLOUD_PROJECT"), # 必填
|
|
|
+ location=os.getenv("GOOGLE_CLOUD_LOCATION") # 必填
|
|
|
)
|
|
|
result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
|
|
|
logger.info("结束:reflection")
|
|
@@ -266,11 +289,18 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
|
|
|
)
|
|
|
|
|
|
# init Reasoning Model, default to Gemini 2.5 Flash
|
|
|
- llm = ChatGoogleGenerativeAI(
|
|
|
- model=reasoning_model,
|
|
|
- temperature=0,
|
|
|
+ # llm = ChatGoogleGenerativeAI(
|
|
|
+ # model=reasoning_model,
|
|
|
+ # temperature=0,
|
|
|
+ # max_retries=2,
|
|
|
+ # api_key=os.getenv("GEMINI_API_KEY"),
|
|
|
+ # )
|
|
|
+ llm = ChatVertexAI(
|
|
|
+        model_name=reasoning_model,  # e.g. "gemini-pro"; was wrongly copied from generate_query as query_generator_model
|
|
|
+        temperature=0,  # keep deterministic output for the final answer, matching the original ChatGoogleGenerativeAI config
|
|
|
max_retries=2,
|
|
|
- api_key=os.getenv("GEMINI_API_KEY"),
|
|
|
+ project=os.getenv("GOOGLE_CLOUD_PROJECT"), # 必填
|
|
|
+ location=os.getenv("GOOGLE_CLOUD_LOCATION") # 必填
|
|
|
)
|
|
|
logger.info("开始:llm.invoke")
|
|
|
result = llm.invoke(formatted_prompt)
|