mcp_client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package utils
  15. import (
  16. "bufio"
  17. "context"
  18. "encoding/json"
  19. "fmt"
  20. "io"
  21. "net/http"
  22. "strings"
  23. "sync"
  24. "sync/atomic"
  25. "time"
  26. "github.com/mark3labs/mcp-go/mcp"
  27. "yunion.io/x/jsonutils"
  28. "yunion.io/x/log"
  29. "yunion.io/x/pkg/errors"
  30. "yunion.io/x/onecloud/pkg/mcclient"
  31. "yunion.io/x/onecloud/pkg/mcclient/auth"
  32. )
  33. // mcpError represents the error object in a JSON-RPC response
  34. type mcpError struct {
  35. Code int `json:"code"`
  36. Message string `json:"message"`
  37. Data interface{} `json:"data,omitempty"`
  38. }
  39. // rawMCPResponse 用于处理 MCP 响应,支持延迟解析 Result
  40. type rawMCPResponse struct {
  41. JSONRPC string `json:"jsonrpc"`
  42. ID mcp.RequestId `json:"id"`
  43. Result json.RawMessage `json:"result,omitempty"`
  44. Error *mcpError `json:"error,omitempty"`
  45. }
  46. // MCPClient 是 MCP Server 的客户端,通过 SSE 协议与 MCP Server 通信
  47. type MCPClient struct {
  48. serverURL string
  49. client *http.Client
  50. sessionURL string
  51. sseBody io.ReadCloser
  52. messageID int64
  53. mu sync.Mutex
  54. initialized bool
  55. userCred mcclient.TokenCredential
  56. pendingReqs map[int64]chan *rawMCPResponse
  57. reqMu sync.Mutex
  58. }
  59. // NewMCPClient 创建一个新的 MCP 客户端
  60. func NewMCPClient(serverURL string, timeout time.Duration, userCred mcclient.TokenCredential) *MCPClient {
  61. return &MCPClient{
  62. serverURL: strings.TrimSuffix(serverURL, "/"),
  63. client: &http.Client{
  64. Timeout: timeout,
  65. },
  66. userCred: userCred,
  67. pendingReqs: make(map[int64]chan *rawMCPResponse),
  68. }
  69. }
  70. // connectSSE 连接 SSE 端点并开始事件循环
  71. func (c *MCPClient) connectSSE(ctx context.Context) error {
  72. // 连接 SSE 端点获取 session URL
  73. sseURL := c.serverURL + "/sse"
  74. req, err := http.NewRequestWithContext(ctx, "GET", sseURL, nil)
  75. if err != nil {
  76. return errors.Wrap(err, "create SSE request")
  77. }
  78. req.Header.Set("Accept", "text/event-stream")
  79. req.Header.Set("Cache-Control", "no-cache")
  80. resp, err := c.client.Do(req)
  81. if err != nil {
  82. return errors.Wrap(err, "connect to SSE")
  83. }
  84. if resp.StatusCode != http.StatusOK {
  85. resp.Body.Close()
  86. body, _ := io.ReadAll(resp.Body)
  87. return errors.Errorf("SSE connection failed with status %d: %s", resp.StatusCode, string(body))
  88. }
  89. c.sseBody = resp.Body
  90. // Channel to signal session URL found
  91. done := make(chan struct{})
  92. var initErr error
  93. // 读取 endpoint 事件获取 session URL
  94. go func() {
  95. reader := bufio.NewReader(c.sseBody)
  96. foundSession := false
  97. defer func() {
  98. if !foundSession {
  99. select {
  100. case <-done:
  101. default:
  102. close(done)
  103. }
  104. }
  105. }()
  106. for {
  107. line, err := reader.ReadString('\n')
  108. if err != nil {
  109. if !foundSession {
  110. initErr = err
  111. } else {
  112. log.Warningf("SSE connection closed: %v", err)
  113. }
  114. return
  115. }
  116. line = strings.TrimSpace(line)
  117. if strings.HasPrefix(line, "data:") {
  118. data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
  119. if !foundSession {
  120. if strings.Contains(data, "/message") {
  121. // 解析 session URL
  122. c.sessionURL = c.serverURL + data
  123. log.Infof("MCP Client initialized with session URL: %s", c.sessionURL)
  124. foundSession = true
  125. close(done)
  126. }
  127. } else {
  128. // 尝试解析为 JSON-RPC 响应
  129. var resp rawMCPResponse
  130. if err := json.Unmarshal([]byte(data), &resp); err == nil && resp.JSONRPC == mcp.JSONRPC_VERSION {
  131. // 提取 ID
  132. var reqID int64
  133. if idVal, ok := resp.ID.Value().(int64); ok {
  134. reqID = idVal
  135. } else if idVal, ok := resp.ID.Value().(float64); ok {
  136. reqID = int64(idVal)
  137. } else {
  138. // 可能是通知或 ID 类型不匹配,忽略
  139. continue
  140. }
  141. c.reqMu.Lock()
  142. ch, ok := c.pendingReqs[reqID]
  143. if ok {
  144. delete(c.pendingReqs, reqID)
  145. }
  146. c.reqMu.Unlock()
  147. if ok {
  148. select {
  149. case ch <- &resp:
  150. default:
  151. log.Warningf("response channel blocked for request %d", reqID)
  152. }
  153. }
  154. }
  155. }
  156. }
  157. }
  158. }()
  159. // Wait for session URL
  160. select {
  161. case <-done:
  162. if initErr != nil {
  163. c.sseBody.Close()
  164. return errors.Wrap(initErr, "read SSE event")
  165. }
  166. case <-time.After(10 * time.Second):
  167. c.sseBody.Close()
  168. return errors.Error("timeout waiting for session URL")
  169. case <-ctx.Done():
  170. c.sseBody.Close()
  171. return ctx.Err()
  172. }
  173. return nil
  174. }
  175. // Initialize 初始化 MCP 客户端连接
  176. func (c *MCPClient) Initialize(ctx context.Context) error {
  177. c.mu.Lock()
  178. defer c.mu.Unlock()
  179. if c.initialized {
  180. return nil
  181. }
  182. if err := c.connectSSE(ctx); err != nil {
  183. return err
  184. }
  185. // 发送初始化请求
  186. initParams := mcp.InitializeParams{
  187. ProtocolVersion: "2024-11-05",
  188. Capabilities: mcp.ClientCapabilities{},
  189. ClientInfo: mcp.Implementation{
  190. Name: "cloudpods-mcp-agent",
  191. Version: "1.0.0",
  192. },
  193. }
  194. initReq := mcp.JSONRPCRequest{
  195. JSONRPC: mcp.JSONRPC_VERSION,
  196. ID: mcp.NewRequestId(c.nextMessageID()),
  197. Params: initParams,
  198. }
  199. initReq.Method = string(mcp.MethodInitialize)
  200. _, err := c.sendRequest(ctx, initReq)
  201. if err != nil {
  202. c.sseBody.Close()
  203. return errors.Wrap(err, "send initialize request")
  204. }
  205. // 发送 initialized 通知
  206. notifyReq := mcp.JSONRPCRequest{
  207. JSONRPC: mcp.JSONRPC_VERSION,
  208. }
  209. notifyReq.Method = "notifications/initialized"
  210. _, err = c.sendRequest(ctx, notifyReq)
  211. if err != nil {
  212. log.Warningf("send initialized notification failed: %v", err)
  213. }
  214. c.initialized = true
  215. return nil
  216. }
  217. // nextMessageID 生成下一个消息 ID
  218. func (c *MCPClient) nextMessageID() int64 {
  219. return atomic.AddInt64(&c.messageID, 1)
  220. }
  221. // sendRequest 发送 JSON-RPC 请求
  222. func (c *MCPClient) sendRequest(ctx context.Context, req mcp.JSONRPCRequest) (*rawMCPResponse, error) {
  223. var respChan chan *rawMCPResponse
  224. var reqID int64
  225. var hasID bool
  226. if !req.ID.IsNil() {
  227. if idVal, ok := req.ID.Value().(int64); ok {
  228. reqID = idVal
  229. hasID = true
  230. }
  231. }
  232. if hasID {
  233. respChan = make(chan *rawMCPResponse, 1)
  234. c.reqMu.Lock()
  235. c.pendingReqs[reqID] = respChan
  236. c.reqMu.Unlock()
  237. // 确保在出错返回时清理 pendingReqs
  238. defer func() {
  239. c.reqMu.Lock()
  240. delete(c.pendingReqs, reqID)
  241. c.reqMu.Unlock()
  242. }()
  243. }
  244. reqBody := jsonutils.Marshal(req)
  245. log.Infof("MCP request: %s", reqBody.String())
  246. cli := auth.Client()
  247. if cli == nil {
  248. cli = mcclient.NewClient("", 0, false, true, "", "")
  249. }
  250. cred := c.userCred
  251. if cred == nil {
  252. log.Warningf("userCred is nil in sendRequest, creating empty token")
  253. cred = &mcclient.SSimpleToken{}
  254. }
  255. s := cli.NewSession(ctx, "", "", "", cred)
  256. s.SetServiceUrl("mcp", c.sessionURL)
  257. _, respBody, err := s.JSONRequest("mcp", "", "POST", "", nil, reqBody)
  258. if err != nil {
  259. return nil, errors.Wrap(err, "send request")
  260. }
  261. // 对于通知请求,可能没有响应体
  262. if !hasID {
  263. return nil, nil
  264. }
  265. // 如果有响应体,直接解析
  266. if respBody != nil {
  267. log.Debugf("MCP response (HTTP): %s", respBody.String())
  268. var mcpResp rawMCPResponse
  269. if err := respBody.Unmarshal(&mcpResp); err != nil {
  270. return nil, errors.Wrap(err, "decode response")
  271. }
  272. if mcpResp.Error != nil {
  273. return nil, errors.Errorf("MCP error %d: %s", mcpResp.Error.Code, mcpResp.Error.Message)
  274. }
  275. // 成功收到 HTTP 响应,从 pending 中移除(defer 会做,但我们可以提前返回)
  276. return &mcpResp, nil
  277. }
  278. // 如果响应为空,等待 SSE 推送
  279. select {
  280. case mcpResp := <-respChan:
  281. log.Debugf("MCP response (SSE): ID=%v", mcpResp.ID)
  282. if mcpResp.Error != nil {
  283. return nil, errors.Errorf("MCP error %d: %s", mcpResp.Error.Code, mcpResp.Error.Message)
  284. }
  285. return mcpResp, nil
  286. case <-ctx.Done():
  287. return nil, ctx.Err()
  288. case <-time.After(30 * time.Second):
  289. return nil, errors.Error("timeout waiting for SSE response")
  290. }
  291. }
  292. // ListTools 获取可用工具列表
  293. func (c *MCPClient) ListTools(ctx context.Context) ([]mcp.Tool, error) {
  294. if !c.initialized {
  295. if err := c.Initialize(ctx); err != nil {
  296. return nil, errors.Wrap(err, "initialize client")
  297. }
  298. }
  299. req := mcp.JSONRPCRequest{
  300. JSONRPC: mcp.JSONRPC_VERSION,
  301. ID: mcp.NewRequestId(c.nextMessageID()),
  302. }
  303. req.Method = string(mcp.MethodToolsList)
  304. resp, err := c.sendRequest(ctx, req)
  305. if err != nil {
  306. return nil, errors.Wrap(err, "send tools/list request")
  307. }
  308. if resp == nil {
  309. return nil, errors.Error("empty response for tools/list")
  310. }
  311. var result mcp.ListToolsResult
  312. if err := json.Unmarshal(resp.Result, &result); err != nil {
  313. return nil, errors.Wrap(err, "decode tools list result")
  314. }
  315. return result.Tools, nil
  316. }
  317. // CallTool 调用工具
  318. func (c *MCPClient) CallTool(ctx context.Context, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
  319. if !c.initialized {
  320. if err := c.Initialize(ctx); err != nil {
  321. return nil, errors.Wrap(err, "initialize client")
  322. }
  323. }
  324. params := mcp.CallToolParams{
  325. Name: toolName,
  326. Arguments: arguments,
  327. }
  328. req := mcp.JSONRPCRequest{
  329. JSONRPC: mcp.JSONRPC_VERSION,
  330. ID: mcp.NewRequestId(c.nextMessageID()),
  331. Params: params,
  332. }
  333. req.Method = string(mcp.MethodToolsCall)
  334. resp, err := c.sendRequest(ctx, req)
  335. if err != nil {
  336. return nil, errors.Wrap(err, "send tools/call request")
  337. }
  338. if resp == nil {
  339. return nil, errors.Error("empty response for tools/call")
  340. }
  341. var result mcp.CallToolResult
  342. if err := json.Unmarshal(resp.Result, &result); err != nil {
  343. return nil, errors.Wrap(err, "decode tool call result")
  344. }
  345. return &result, nil
  346. }
  347. // GetToolResultText 从工具调用结果中提取文本
  348. func GetToolResultText(r *mcp.CallToolResult) string {
  349. var texts []string
  350. for _, content := range r.Content {
  351. if textContent, ok := content.(mcp.TextContent); ok {
  352. texts = append(texts, textContent.Text)
  353. }
  354. }
  355. return strings.Join(texts, "\n")
  356. }
  357. // FormatToolResult 格式化工具调用结果
  358. func FormatToolResult(toolName string, result *mcp.CallToolResult, err error) string {
  359. if err != nil {
  360. return fmt.Sprintf("工具 %s 调用失败: %v", toolName, err)
  361. }
  362. if result.IsError {
  363. return fmt.Sprintf("工具 %s 返回错误: %s", toolName, GetToolResultText(result))
  364. }
  365. return GetToolResultText(result)
  366. }
  367. // Close 关闭客户端连接
  368. func (c *MCPClient) Close() error {
  369. c.mu.Lock()
  370. defer c.mu.Unlock()
  371. c.initialized = false
  372. c.sessionURL = ""
  373. if c.sseBody != nil {
  374. c.sseBody.Close()
  375. c.sseBody = nil
  376. }
  377. return nil
  378. }