| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package utils
- import (
- "bufio"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/mark3labs/mcp-go/mcp"
- "yunion.io/x/jsonutils"
- "yunion.io/x/log"
- "yunion.io/x/pkg/errors"
- "yunion.io/x/onecloud/pkg/mcclient"
- "yunion.io/x/onecloud/pkg/mcclient/auth"
- )
- // mcpError represents the error object in a JSON-RPC response
- type mcpError struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data interface{} `json:"data,omitempty"`
- }
- // rawMCPResponse 用于处理 MCP 响应,支持延迟解析 Result
- type rawMCPResponse struct {
- JSONRPC string `json:"jsonrpc"`
- ID mcp.RequestId `json:"id"`
- Result json.RawMessage `json:"result,omitempty"`
- Error *mcpError `json:"error,omitempty"`
- }
- // MCPClient 是 MCP Server 的客户端,通过 SSE 协议与 MCP Server 通信
- type MCPClient struct {
- serverURL string
- client *http.Client
- sessionURL string
- sseBody io.ReadCloser
- messageID int64
- mu sync.Mutex
- initialized bool
- userCred mcclient.TokenCredential
- pendingReqs map[int64]chan *rawMCPResponse
- reqMu sync.Mutex
- }
- // NewMCPClient 创建一个新的 MCP 客户端
- func NewMCPClient(serverURL string, timeout time.Duration, userCred mcclient.TokenCredential) *MCPClient {
- return &MCPClient{
- serverURL: strings.TrimSuffix(serverURL, "/"),
- client: &http.Client{
- Timeout: timeout,
- },
- userCred: userCred,
- pendingReqs: make(map[int64]chan *rawMCPResponse),
- }
- }
- // connectSSE 连接 SSE 端点并开始事件循环
- func (c *MCPClient) connectSSE(ctx context.Context) error {
- // 连接 SSE 端点获取 session URL
- sseURL := c.serverURL + "/sse"
- req, err := http.NewRequestWithContext(ctx, "GET", sseURL, nil)
- if err != nil {
- return errors.Wrap(err, "create SSE request")
- }
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Cache-Control", "no-cache")
- resp, err := c.client.Do(req)
- if err != nil {
- return errors.Wrap(err, "connect to SSE")
- }
- if resp.StatusCode != http.StatusOK {
- resp.Body.Close()
- body, _ := io.ReadAll(resp.Body)
- return errors.Errorf("SSE connection failed with status %d: %s", resp.StatusCode, string(body))
- }
- c.sseBody = resp.Body
- // Channel to signal session URL found
- done := make(chan struct{})
- var initErr error
- // 读取 endpoint 事件获取 session URL
- go func() {
- reader := bufio.NewReader(c.sseBody)
- foundSession := false
- defer func() {
- if !foundSession {
- select {
- case <-done:
- default:
- close(done)
- }
- }
- }()
- for {
- line, err := reader.ReadString('\n')
- if err != nil {
- if !foundSession {
- initErr = err
- } else {
- log.Warningf("SSE connection closed: %v", err)
- }
- return
- }
- line = strings.TrimSpace(line)
- if strings.HasPrefix(line, "data:") {
- data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
- if !foundSession {
- if strings.Contains(data, "/message") {
- // 解析 session URL
- c.sessionURL = c.serverURL + data
- log.Infof("MCP Client initialized with session URL: %s", c.sessionURL)
- foundSession = true
- close(done)
- }
- } else {
- // 尝试解析为 JSON-RPC 响应
- var resp rawMCPResponse
- if err := json.Unmarshal([]byte(data), &resp); err == nil && resp.JSONRPC == mcp.JSONRPC_VERSION {
- // 提取 ID
- var reqID int64
- if idVal, ok := resp.ID.Value().(int64); ok {
- reqID = idVal
- } else if idVal, ok := resp.ID.Value().(float64); ok {
- reqID = int64(idVal)
- } else {
- // 可能是通知或 ID 类型不匹配,忽略
- continue
- }
- c.reqMu.Lock()
- ch, ok := c.pendingReqs[reqID]
- if ok {
- delete(c.pendingReqs, reqID)
- }
- c.reqMu.Unlock()
- if ok {
- select {
- case ch <- &resp:
- default:
- log.Warningf("response channel blocked for request %d", reqID)
- }
- }
- }
- }
- }
- }
- }()
- // Wait for session URL
- select {
- case <-done:
- if initErr != nil {
- c.sseBody.Close()
- return errors.Wrap(initErr, "read SSE event")
- }
- case <-time.After(10 * time.Second):
- c.sseBody.Close()
- return errors.Error("timeout waiting for session URL")
- case <-ctx.Done():
- c.sseBody.Close()
- return ctx.Err()
- }
- return nil
- }
- // Initialize 初始化 MCP 客户端连接
- func (c *MCPClient) Initialize(ctx context.Context) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.initialized {
- return nil
- }
- if err := c.connectSSE(ctx); err != nil {
- return err
- }
- // 发送初始化请求
- initParams := mcp.InitializeParams{
- ProtocolVersion: "2024-11-05",
- Capabilities: mcp.ClientCapabilities{},
- ClientInfo: mcp.Implementation{
- Name: "cloudpods-mcp-agent",
- Version: "1.0.0",
- },
- }
- initReq := mcp.JSONRPCRequest{
- JSONRPC: mcp.JSONRPC_VERSION,
- ID: mcp.NewRequestId(c.nextMessageID()),
- Params: initParams,
- }
- initReq.Method = string(mcp.MethodInitialize)
- _, err := c.sendRequest(ctx, initReq)
- if err != nil {
- c.sseBody.Close()
- return errors.Wrap(err, "send initialize request")
- }
- // 发送 initialized 通知
- notifyReq := mcp.JSONRPCRequest{
- JSONRPC: mcp.JSONRPC_VERSION,
- }
- notifyReq.Method = "notifications/initialized"
- _, err = c.sendRequest(ctx, notifyReq)
- if err != nil {
- log.Warningf("send initialized notification failed: %v", err)
- }
- c.initialized = true
- return nil
- }
- // nextMessageID 生成下一个消息 ID
- func (c *MCPClient) nextMessageID() int64 {
- return atomic.AddInt64(&c.messageID, 1)
- }
- // sendRequest 发送 JSON-RPC 请求
- func (c *MCPClient) sendRequest(ctx context.Context, req mcp.JSONRPCRequest) (*rawMCPResponse, error) {
- var respChan chan *rawMCPResponse
- var reqID int64
- var hasID bool
- if !req.ID.IsNil() {
- if idVal, ok := req.ID.Value().(int64); ok {
- reqID = idVal
- hasID = true
- }
- }
- if hasID {
- respChan = make(chan *rawMCPResponse, 1)
- c.reqMu.Lock()
- c.pendingReqs[reqID] = respChan
- c.reqMu.Unlock()
- // 确保在出错返回时清理 pendingReqs
- defer func() {
- c.reqMu.Lock()
- delete(c.pendingReqs, reqID)
- c.reqMu.Unlock()
- }()
- }
- reqBody := jsonutils.Marshal(req)
- log.Infof("MCP request: %s", reqBody.String())
- cli := auth.Client()
- if cli == nil {
- cli = mcclient.NewClient("", 0, false, true, "", "")
- }
- cred := c.userCred
- if cred == nil {
- log.Warningf("userCred is nil in sendRequest, creating empty token")
- cred = &mcclient.SSimpleToken{}
- }
- s := cli.NewSession(ctx, "", "", "", cred)
- s.SetServiceUrl("mcp", c.sessionURL)
- _, respBody, err := s.JSONRequest("mcp", "", "POST", "", nil, reqBody)
- if err != nil {
- return nil, errors.Wrap(err, "send request")
- }
- // 对于通知请求,可能没有响应体
- if !hasID {
- return nil, nil
- }
- // 如果有响应体,直接解析
- if respBody != nil {
- log.Debugf("MCP response (HTTP): %s", respBody.String())
- var mcpResp rawMCPResponse
- if err := respBody.Unmarshal(&mcpResp); err != nil {
- return nil, errors.Wrap(err, "decode response")
- }
- if mcpResp.Error != nil {
- return nil, errors.Errorf("MCP error %d: %s", mcpResp.Error.Code, mcpResp.Error.Message)
- }
- // 成功收到 HTTP 响应,从 pending 中移除(defer 会做,但我们可以提前返回)
- return &mcpResp, nil
- }
- // 如果响应为空,等待 SSE 推送
- select {
- case mcpResp := <-respChan:
- log.Debugf("MCP response (SSE): ID=%v", mcpResp.ID)
- if mcpResp.Error != nil {
- return nil, errors.Errorf("MCP error %d: %s", mcpResp.Error.Code, mcpResp.Error.Message)
- }
- return mcpResp, nil
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-time.After(30 * time.Second):
- return nil, errors.Error("timeout waiting for SSE response")
- }
- }
- // ListTools 获取可用工具列表
- func (c *MCPClient) ListTools(ctx context.Context) ([]mcp.Tool, error) {
- if !c.initialized {
- if err := c.Initialize(ctx); err != nil {
- return nil, errors.Wrap(err, "initialize client")
- }
- }
- req := mcp.JSONRPCRequest{
- JSONRPC: mcp.JSONRPC_VERSION,
- ID: mcp.NewRequestId(c.nextMessageID()),
- }
- req.Method = string(mcp.MethodToolsList)
- resp, err := c.sendRequest(ctx, req)
- if err != nil {
- return nil, errors.Wrap(err, "send tools/list request")
- }
- if resp == nil {
- return nil, errors.Error("empty response for tools/list")
- }
- var result mcp.ListToolsResult
- if err := json.Unmarshal(resp.Result, &result); err != nil {
- return nil, errors.Wrap(err, "decode tools list result")
- }
- return result.Tools, nil
- }
- // CallTool 调用工具
- func (c *MCPClient) CallTool(ctx context.Context, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
- if !c.initialized {
- if err := c.Initialize(ctx); err != nil {
- return nil, errors.Wrap(err, "initialize client")
- }
- }
- params := mcp.CallToolParams{
- Name: toolName,
- Arguments: arguments,
- }
- req := mcp.JSONRPCRequest{
- JSONRPC: mcp.JSONRPC_VERSION,
- ID: mcp.NewRequestId(c.nextMessageID()),
- Params: params,
- }
- req.Method = string(mcp.MethodToolsCall)
- resp, err := c.sendRequest(ctx, req)
- if err != nil {
- return nil, errors.Wrap(err, "send tools/call request")
- }
- if resp == nil {
- return nil, errors.Error("empty response for tools/call")
- }
- var result mcp.CallToolResult
- if err := json.Unmarshal(resp.Result, &result); err != nil {
- return nil, errors.Wrap(err, "decode tool call result")
- }
- return &result, nil
- }
- // GetToolResultText 从工具调用结果中提取文本
- func GetToolResultText(r *mcp.CallToolResult) string {
- var texts []string
- for _, content := range r.Content {
- if textContent, ok := content.(mcp.TextContent); ok {
- texts = append(texts, textContent.Text)
- }
- }
- return strings.Join(texts, "\n")
- }
- // FormatToolResult 格式化工具调用结果
- func FormatToolResult(toolName string, result *mcp.CallToolResult, err error) string {
- if err != nil {
- return fmt.Sprintf("工具 %s 调用失败: %v", toolName, err)
- }
- if result.IsError {
- return fmt.Sprintf("工具 %s 返回错误: %s", toolName, GetToolResultText(result))
- }
- return GetToolResultText(result)
- }
- // Close 关闭客户端连接
- func (c *MCPClient) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.initialized = false
- c.sessionURL = ""
- if c.sseBody != nil {
- c.sseBody.Close()
- c.sseBody = nil
- }
- return nil
- }
|