utils.py

from typing import Any, Dict, List

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage


def get_research_topic(messages: List[AnyMessage]) -> str:
    """
    Get the research topic from the messages.
    """
    # Check if the request has a history and combine the messages into a single string.
    if len(messages) == 1:
        research_topic = messages[-1].content
    else:
        research_topic = ""
        for message in messages:
            if isinstance(message, HumanMessage):
                research_topic += f"User: {message.content}\n"
            elif isinstance(message, AIMessage):
                research_topic += f"Assistant: {message.content}\n"
    return research_topic
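
# Usage sketch (illustrative, not part of the module's API): with a
# multi-turn history, get_research_topic collapses it into a transcript.
#
#   msgs = [
#       HumanMessage(content="What is quantum computing?"),
#       AIMessage(content="A paradigm built on qubits."),
#       HumanMessage(content="Summarize recent breakthroughs."),
#   ]
#   get_research_topic(msgs)
#   # -> "User: What is quantum computing?\nAssistant: A paradigm built on
#   #     qubits.\nUser: Summarize recent breakthroughs.\n"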


def resolve_urls(urls_to_resolve: List[Any], id: int) -> Dict[str, str]:
    """
    Create a map of the Vertex AI Search URLs (which are very long) to short
    URLs with a unique id for each URL. Ensures each original URL gets a
    consistent shortened form while maintaining uniqueness.
    """
    prefix = "https://vertexaisearch.cloud.google.com/id/"
    urls = [site.web.uri for site in urls_to_resolve]

    # Map each unique URL to a short URL built from its first-occurrence index.
    resolved_map = {}
    for idx, url in enumerate(urls):
        if url not in resolved_map:
            resolved_map[url] = f"{prefix}{id}-{idx}"

    return resolved_map
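
# Usage sketch (illustrative): the inputs are grounding chunks exposing
# `.web.uri`; SimpleNamespace stands in for the real API objects here.
#
#   from types import SimpleNamespace
#   chunks = [
#       SimpleNamespace(web=SimpleNamespace(uri="https://example.com/a")),
#       SimpleNamespace(web=SimpleNamespace(uri="https://example.com/b")),
#       SimpleNamespace(web=SimpleNamespace(uri="https://example.com/a")),
#   ]
#   resolve_urls(chunks, id=7)
#   # -> {"https://example.com/a": "https://vertexaisearch.cloud.google.com/id/7-0",
#   #     "https://example.com/b": "https://vertexaisearch.cloud.google.com/id/7-1"}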


def insert_citation_markers(text, citations_list):
    """
    Inserts citation markers into a text string based on start and end indices.

    Args:
        text (str): The original text string.
        citations_list (list): A list of dictionaries, where each dictionary
                               contains 'start_index', 'end_index', and
                               'segments' (the markers to insert). Indices
                               are assumed to be for the original text.

    Returns:
        str: The text with citation markers inserted.
    """
    # Sort citations by end_index in descending order.
    # If end_index is the same, secondary sort by start_index descending.
    # This ensures that insertions at the end of the string don't affect
    # the indices of earlier parts of the string that still need to be processed.
    sorted_citations = sorted(
        citations_list, key=lambda c: (c["end_index"], c["start_index"]), reverse=True
    )

    modified_text = text
    for citation_info in sorted_citations:
        # These indices refer to positions in the *original* text,
        # but since we iterate from the end, they remain valid for insertion
        # relative to the parts of the string already processed.
        end_idx = citation_info["end_index"]
        marker_to_insert = ""
        for segment in citation_info["segments"]:
            marker_to_insert += f" [{segment['label']}]({segment['short_url']})"
        # Insert the citation marker at the original end_idx position.
        modified_text = (
            modified_text[:end_idx] + marker_to_insert + modified_text[end_idx:]
        )

    return modified_text
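
# Usage sketch (illustrative): a marker is appended at each citation's
# end_index; processing runs back-to-front so earlier indices stay valid.
#
#   text = "Quantum computers use qubits."
#   citations = [
#       {
#           "start_index": 0,
#           "end_index": 29,
#           "segments": [{"label": "example", "short_url": "https://short/0-0"}],
#       }
#   ]
#   insert_citation_markers(text, citations)
#   # -> "Quantum computers use qubits. [example](https://short/0-0)"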


def get_citations(response, resolved_urls_map):
    """
    Extracts and formats citation information from a Gemini model's response.

    This function processes the grounding metadata provided in the response to
    construct a list of citation objects. Each citation object includes the
    start and end indices of the text segment it refers to, and a list of
    formatted markdown links to the supporting web chunks.

    Args:
        response: The response object from the Gemini model, expected to have
                  a structure including `candidates[0].grounding_metadata`.
        resolved_urls_map: A mapping from chunk URIs to their resolved
                           (shortened) URLs, as produced by `resolve_urls`.

    Returns:
        list: A list of dictionaries, where each dictionary represents a
              citation and has the following keys:
              - "start_index" (int): The starting character index of the cited
                                     segment in the original text. Defaults to
                                     0 if not specified.
              - "end_index" (int): The character index immediately after the
                                   end of the cited segment (exclusive).
              - "segments" (list[dict]): One entry per grounding chunk, with
                                         "label", "short_url", and "value"
                                         (the original URI) keys.
              Returns an empty list if no valid candidates or grounding
              supports are found, or if essential data is missing.
    """
    citations = []

    # Ensure response and necessary nested structures are present.
    if not response or not response.candidates:
        return citations

    candidate = response.candidates[0]
    if (
        not hasattr(candidate, "grounding_metadata")
        or not candidate.grounding_metadata
        or not hasattr(candidate.grounding_metadata, "grounding_supports")
    ):
        return citations

    for support in candidate.grounding_metadata.grounding_supports:
        citation = {}

        # Ensure segment information is present.
        if not hasattr(support, "segment") or support.segment is None:
            continue  # Skip this support if segment info is missing.

        start_index = (
            support.segment.start_index
            if support.segment.start_index is not None
            else 0
        )

        # Ensure end_index is present to form a valid segment.
        if support.segment.end_index is None:
            continue  # Skip if end_index is missing, as it's crucial.

        # end_index is used as-is and treated as an exclusive end for
        # slicing/range purposes.
        citation["start_index"] = start_index
        citation["end_index"] = support.segment.end_index

        citation["segments"] = []
        if (
            hasattr(support, "grounding_chunk_indices")
            and support.grounding_chunk_indices
        ):
            for ind in support.grounding_chunk_indices:
                try:
                    chunk = candidate.grounding_metadata.grounding_chunks[ind]
                    resolved_url = resolved_urls_map.get(chunk.web.uri)
                    # Use the title (minus a trailing dot-separated suffix,
                    # e.g. a file extension) as the link label; fall back to
                    # the full title when it contains no dot.
                    title = chunk.web.title or ""
                    label = title.rsplit(".", 1)[0] if "." in title else title
                    citation["segments"].append(
                        {
                            "label": label,
                            "short_url": resolved_url,
                            "value": chunk.web.uri,
                        }
                    )
                except (IndexError, AttributeError):
                    # Handle cases where the chunk, web, or uri is missing or
                    # malformed. For simplicity, skip this segment link; a
                    # production system might log this instead.
                    pass
        citations.append(citation)
    return citations
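

if __name__ == "__main__":
    # End-to-end sketch (illustrative): the SimpleNamespace objects below
    # mimic only the attributes this module reads from a Gemini response's
    # grounding metadata; a real response comes from the Gemini API client.
    from types import SimpleNamespace

    chunk = SimpleNamespace(
        web=SimpleNamespace(uri="https://example.com/article", title="example.com")
    )
    support = SimpleNamespace(
        segment=SimpleNamespace(start_index=0, end_index=29),
        grounding_chunk_indices=[0],
    )
    response = SimpleNamespace(
        candidates=[
            SimpleNamespace(
                grounding_metadata=SimpleNamespace(
                    grounding_supports=[support], grounding_chunks=[chunk]
                )
            )
        ]
    )

    resolved = resolve_urls([chunk], id=0)
    citations = get_citations(response, resolved)
    print(insert_citation_markers("Quantum computers use qubits.", citations))
    # Expected: Quantum computers use qubits. [example](https://vertexaisearch.cloud.google.com/id/0-0)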