mcp_agent.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package handler
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "yunion.io/x/log"
  11. "yunion.io/x/pkg/errors"
  12. "yunion.io/x/pkg/util/httputils"
  13. "yunion.io/x/onecloud/pkg/apigateway/options"
  14. "yunion.io/x/onecloud/pkg/appsrv"
  15. "yunion.io/x/onecloud/pkg/httperrors"
  16. "yunion.io/x/onecloud/pkg/mcclient/auth"
  17. modules "yunion.io/x/onecloud/pkg/mcclient/modules/llm"
  18. mcpServerOption "yunion.io/x/onecloud/pkg/mcp-server/options"
  19. )
  20. func mcpServersConfigHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  21. serviceName := "mcp-server"
  22. url, err := auth.GetPublicServiceURL(serviceName, options.Options.Region, "", httputils.GET)
  23. if err != nil {
  24. log.Warningf("GetPublicServiceURL for %s failed: %v", serviceName, err)
  25. }
  26. sseURL := fmt.Sprintf("%s/sse", url)
  27. responseType := r.URL.Query().Get("type")
  28. switch responseType {
  29. case "claude":
  30. // Claude 仅支持单个自定义 header,使用 X-API-Key。填写方式:
  31. // base64(ak:sk):`echo -n "你的AK:你的SK" | base64`,将输出填入
  32. cmd := fmt.Sprintf("claude mcp add --transport sse %s --header \"X-API-Key: <填写 token 或 base64(AK:SK)>\"", sseURL)
  33. w.Header().Set("Content-Type", "text/plain; charset=utf-8")
  34. w.Write([]byte(cmd))
  35. return
  36. case "cursor":
  37. // fall through to JSON
  38. default:
  39. // default: return JSON (cursor format)
  40. }
  41. // Cursor:在 headers 中填写控制台/CLI 获取的 Access Key 与 Secret Key
  42. config := map[string]interface{}{
  43. "mcpServers": map[string]interface{}{
  44. mcpServerOption.Options.MCPServerName: map[string]interface{}{
  45. "url": sseURL,
  46. "headers": map[string]string{
  47. "AK": "<填写 Access Key>",
  48. "SK": "<填写 Secret Key>",
  49. },
  50. },
  51. },
  52. }
  53. w.Header().Set("Content-Type", "application/json")
  54. json.NewEncoder(w).Encode(config)
  55. }
  56. func chatHandlerInfo(method, prefix string, handler func(context.Context, http.ResponseWriter, *http.Request)) *appsrv.SHandlerInfo {
  57. log.Debugf("%s - %s", method, prefix)
  58. hi := appsrv.SHandlerInfo{}
  59. hi.SetMethod(method)
  60. hi.SetPath(prefix)
  61. hi.SetHandler(handler)
  62. hi.SetProcessTimeout(6 * time.Hour)
  63. // Use default worker manager with default pool size (usually 32)
  64. // instead of uploader worker which has limited pool size (4)
  65. return &hi
  66. }
  67. func mcpAgentChatStreamHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  68. params, _, body := appsrv.FetchEnv(ctx, w, r)
  69. id := params["<id>"]
  70. if len(id) == 0 {
  71. httperrors.MissingParameterError(ctx, w, "id")
  72. return
  73. }
  74. token := AppContextToken(ctx)
  75. s := auth.GetSession(ctx, token, FetchRegion(r))
  76. // Prepare request to backend
  77. headers := http.Header{}
  78. headers.Set("Content-Type", "application/json")
  79. // Forward the request body to the backend
  80. var bodyReader io.Reader
  81. if body != nil {
  82. bodyStr := body.String()
  83. bodyReader = strings.NewReader(bodyStr)
  84. }
  85. path := fmt.Sprintf("/mcp_agents/%s/chat-stream", id)
  86. resp, err := s.RawVersionRequest(
  87. modules.MCPAgent.ServiceType(),
  88. modules.MCPAgent.EndpointType(),
  89. "POST",
  90. path,
  91. headers,
  92. bodyReader,
  93. )
  94. if err != nil {
  95. httperrors.GeneralServerError(ctx, w, errors.Wrap(err, "request backend"))
  96. return
  97. }
  98. defer resp.Body.Close()
  99. if resp.StatusCode != 200 {
  100. // Read error body
  101. respBody, _ := io.ReadAll(resp.Body)
  102. // Try to parse as JSON error if possible, or just return as is
  103. if resp.StatusCode >= 400 && resp.StatusCode < 500 {
  104. httperrors.InputParameterError(ctx, w, "backend error: %s", string(respBody))
  105. } else {
  106. httperrors.GeneralServerError(ctx, w, fmt.Errorf("backend error %d: %s", resp.StatusCode, string(respBody)))
  107. }
  108. return
  109. }
  110. // Set SSE headers
  111. w.Header().Set("Content-Type", "text/event-stream")
  112. w.Header().Set("Cache-Control", "no-cache")
  113. w.Header().Set("Connection", "keep-alive")
  114. // For now just standard SSE headers.
  115. if f, ok := w.(http.Flusher); ok {
  116. f.Flush()
  117. }
  118. // Stream the response from backend to client
  119. buf := make([]byte, 1024)
  120. for {
  121. n, err := resp.Body.Read(buf)
  122. if n > 0 {
  123. if _, wErr := w.Write(buf[:n]); wErr != nil {
  124. log.Errorf("write response error: %v", wErr)
  125. return
  126. }
  127. if f, ok := w.(http.Flusher); ok {
  128. f.Flush()
  129. }
  130. }
  131. if err != nil {
  132. if err != io.EOF {
  133. log.Errorf("read backend response error: %v", err)
  134. }
  135. break
  136. }
  137. }
  138. }
  139. // mcpAgentDefaultChatStreamHandler 将请求转发到 region 的 default-chat-stream(使用 default_agent=true 的条目)
  140. func mcpAgentDefaultChatStreamHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  141. token := AppContextToken(ctx)
  142. s := auth.GetSession(ctx, token, FetchRegion(r))
  143. headers := http.Header{}
  144. headers.Set("Content-Type", "application/json")
  145. var bodyReader io.Reader
  146. if r.Body != nil {
  147. bodyReader = r.Body
  148. }
  149. path := "/mcp_agents/default-chat-stream"
  150. resp, err := s.RawVersionRequest(
  151. modules.MCPAgent.ServiceType(),
  152. modules.MCPAgent.EndpointType(),
  153. "POST",
  154. path,
  155. headers,
  156. bodyReader,
  157. )
  158. if err != nil {
  159. httperrors.GeneralServerError(ctx, w, errors.Wrap(err, "request backend"))
  160. return
  161. }
  162. defer resp.Body.Close()
  163. if resp.StatusCode != 200 {
  164. respBody, _ := io.ReadAll(resp.Body)
  165. if resp.StatusCode >= 400 && resp.StatusCode < 500 {
  166. httperrors.InputParameterError(ctx, w, "backend error: %s", string(respBody))
  167. } else {
  168. httperrors.GeneralServerError(ctx, w, fmt.Errorf("backend error %d: %s", resp.StatusCode, string(respBody)))
  169. }
  170. return
  171. }
  172. w.Header().Set("Content-Type", "text/event-stream")
  173. w.Header().Set("Cache-Control", "no-cache")
  174. w.Header().Set("Connection", "keep-alive")
  175. if f, ok := w.(http.Flusher); ok {
  176. f.Flush()
  177. }
  178. buf := make([]byte, 1024)
  179. for {
  180. n, err := resp.Body.Read(buf)
  181. if n > 0 {
  182. if _, wErr := w.Write(buf[:n]); wErr != nil {
  183. log.Errorf("write response error: %v", wErr)
  184. return
  185. }
  186. if f, ok := w.(http.Flusher); ok {
  187. f.Flush()
  188. }
  189. }
  190. if err != nil {
  191. if err != io.EOF {
  192. log.Errorf("read backend response error: %v", err)
  193. }
  194. break
  195. }
  196. }
  197. }
  198. // mcpAgentDefaultToolsHandler 将 GET 请求转发到 region 的 default-mcp-tools(仅使用 options.MCPServerURL,不通过 mcp_agent 条目)
  199. func mcpAgentDefaultToolsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  200. token := AppContextToken(ctx)
  201. s := auth.GetSession(ctx, token, FetchRegion(r))
  202. path := "/mcp_agents/default-mcp-tools"
  203. resp, err := s.RawVersionRequest(
  204. modules.MCPAgent.ServiceType(),
  205. modules.MCPAgent.EndpointType(),
  206. "GET",
  207. path,
  208. nil,
  209. nil,
  210. )
  211. if err != nil {
  212. httperrors.GeneralServerError(ctx, w, errors.Wrap(err, "request backend"))
  213. return
  214. }
  215. defer resp.Body.Close()
  216. if resp.StatusCode != 200 {
  217. respBody, _ := io.ReadAll(resp.Body)
  218. if resp.StatusCode >= 400 && resp.StatusCode < 500 {
  219. httperrors.InputParameterError(ctx, w, "backend error: %s", string(respBody))
  220. } else {
  221. httperrors.GeneralServerError(ctx, w, fmt.Errorf("backend error %d: %s", resp.StatusCode, string(respBody)))
  222. }
  223. return
  224. }
  225. w.Header().Set("Content-Type", "application/json")
  226. _, err = io.Copy(w, resp.Body)
  227. if err != nil {
  228. log.Errorf("write default mcp tools response error: %v", err)
  229. }
  230. }