mcp_agent.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. package models
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "yunion.io/x/jsonutils"
  11. "yunion.io/x/log"
  12. "yunion.io/x/pkg/errors"
  13. seclib "yunion.io/x/pkg/utils"
  14. "yunion.io/x/sqlchemy"
  15. api "yunion.io/x/onecloud/pkg/apis/llm"
  16. "yunion.io/x/onecloud/pkg/appsrv"
  17. "yunion.io/x/onecloud/pkg/cloudcommon/db"
  18. "yunion.io/x/onecloud/pkg/cloudcommon/policy"
  19. "yunion.io/x/onecloud/pkg/httperrors"
  20. "yunion.io/x/onecloud/pkg/llm/options"
  21. "yunion.io/x/onecloud/pkg/llm/utils"
  22. "yunion.io/x/onecloud/pkg/mcclient"
  23. "yunion.io/x/onecloud/pkg/util/stringutils2"
  24. )
  25. func init() {
  26. GetMCPAgentManager()
  27. }
  28. var mcpAgentManager *SMCPAgentManager
  29. var mcpAgentWorkerMan *appsrv.SWorkerManager
  30. func GetMCPAgentWorkerManager() *appsrv.SWorkerManager {
  31. return mcpAgentWorkerMan
  32. }
  33. func GetMCPAgentManager() *SMCPAgentManager {
  34. if mcpAgentManager != nil {
  35. return mcpAgentManager
  36. }
  37. mcpAgentManager = &SMCPAgentManager{
  38. SSharableVirtualResourceBaseManager: db.NewSharableVirtualResourceBaseManager(
  39. SMCPAgent{},
  40. "mcp_agents_tbl",
  41. "mcp_agent",
  42. "mcp_agents",
  43. ),
  44. }
  45. mcpAgentManager.SetVirtualObject(mcpAgentManager)
  46. return mcpAgentManager
  47. }
  48. type SMCPAgentManager struct {
  49. db.SSharableVirtualResourceBaseManager
  50. }
  51. // unsetOtherDefaultAgents 将除 excludeId 外所有条目的 default_agent 置为 false,保证全局唯一
  52. func (man *SMCPAgentManager) unsetOtherDefaultAgents(ctx context.Context, excludeId string) error {
  53. q := man.Query().IsTrue("default_agent")
  54. if len(excludeId) > 0 {
  55. q = q.NotEquals("id", excludeId)
  56. }
  57. agents := make([]SMCPAgent, 0)
  58. err := db.FetchModelObjects(man, q, &agents)
  59. if err != nil {
  60. return errors.Wrap(err, "FetchModelObjects")
  61. }
  62. for i := range agents {
  63. _, err := db.Update(&agents[i], func() error {
  64. agents[i].DefaultAgent = false
  65. return nil
  66. })
  67. if err != nil {
  68. return errors.Wrapf(err, "Update agent %s", agents[i].Id)
  69. }
  70. }
  71. return nil
  72. }
  73. // GetDefaultAgent 返回当前用户可见的、default_agent=true 的那条 MCP Agent(仅一条)
  74. func (man *SMCPAgentManager) GetDefaultAgent(ctx context.Context, userCred mcclient.TokenCredential) (*SMCPAgent, error) {
  75. query := jsonutils.NewDict()
  76. query.Set("default_agent", jsonutils.JSONTrue)
  77. ownerId, scope, err, _ := db.FetchCheckQueryOwnerScope(ctx, userCred, query, man, policy.PolicyActionList, true)
  78. if err != nil {
  79. return nil, errors.Wrap(err, "FetchCheckQueryOwnerScope")
  80. }
  81. q := man.Query()
  82. q = man.FilterByOwner(ctx, q, man, userCred, ownerId, scope)
  83. q = q.IsTrue("default_agent")
  84. var agent SMCPAgent
  85. err = q.First(&agent)
  86. if err != nil {
  87. if errors.Cause(err) == sql.ErrNoRows {
  88. return nil, nil
  89. }
  90. return nil, errors.Wrap(err, "First default agent")
  91. }
  92. return &agent, nil
  93. }
  94. // GetDefaultMcpServerTools 返回默认 MCP 服务器(options.Options.MCPServerURL)的 tools,不依赖任何 mcp_agent 记录
  95. func (man *SMCPAgentManager) GetDefaultMcpServerTools(ctx context.Context, userCred mcclient.TokenCredential) (jsonutils.JSONObject, error) {
  96. timeout := time.Duration(options.Options.MCPAgentTimeout) * time.Second
  97. mcpClient := utils.NewMCPClient(options.Options.MCPServerURL, timeout, userCred)
  98. defer mcpClient.Close()
  99. tools, err := mcpClient.ListTools(ctx)
  100. if err != nil {
  101. return nil, errors.Wrap(err, "list default MCP tools")
  102. }
  103. return jsonutils.Marshal(tools), nil
  104. }
  105. type SMCPAgent struct {
  106. db.SSharableVirtualResourceBase
  107. // LLMId 关联的 LLM 实例 ID
  108. LLMId string `width:"128" charset:"ascii" nullable:"true" list:"user" create:"optional" update:"user"`
  109. // LLMUrl 对应后端大模型的 base 请求地址
  110. LLMUrl string `width:"512" charset:"utf8" nullable:"false" list:"user" create:"required" update:"user"`
  111. // LLMDriver 对应使用的大模型驱动(llm_client),现在可以被设置为 ollama 或 openai
  112. LLMDriver string `width:"64" charset:"ascii" nullable:"false" list:"user" create:"required" update:"user"`
  113. // Model 使用的模型名称
  114. Model string `width:"128" charset:"ascii" nullable:"false" list:"user" create:"required" update:"user"`
  115. // ApiKey 即在 llm_driver 中需要用到的认证
  116. ApiKey string `width:"512" charset:"utf8" nullable:"true" list:"user" create:"optional" update:"user"`
  117. // McpServer 即 mcp 服务器的后端地址
  118. McpServer string `width:"512" charset:"utf8" nullable:"false" list:"user" create:"optional" update:"user"`
  119. // DefaultAgent 是否为默认 Agent,全局仅允许一条为 true
  120. DefaultAgent bool `default:"false" list:"user" create:"optional" update:"user"`
  121. }
  122. func (mcp *SMCPAgent) BeforeInsert() {
  123. if len(mcp.Id) == 0 {
  124. mcp.Id = db.DefaultUUIDGenerator()
  125. }
  126. if len(mcp.ApiKey) > 0 {
  127. sec, err := seclib.EncryptAESBase64(mcp.Id, mcp.ApiKey)
  128. if err != nil {
  129. log.Errorf("EncryptAESBase64 fail %s", err)
  130. } else {
  131. mcp.ApiKey = sec
  132. }
  133. }
  134. mcp.SSharableVirtualResourceBase.BeforeInsert()
  135. }
  136. func (mcp *SMCPAgent) BeforeUpdate() {
  137. if len(mcp.ApiKey) > 0 {
  138. // heuristic to check if it is plaintext
  139. _, err := seclib.DescryptAESBase64(mcp.Id, mcp.ApiKey)
  140. if err != nil {
  141. sec, err := seclib.EncryptAESBase64(mcp.Id, mcp.ApiKey)
  142. if err != nil {
  143. log.Errorf("EncryptAESBase64 fail %s", err)
  144. } else {
  145. mcp.ApiKey = sec
  146. }
  147. }
  148. }
  149. }
  150. func (mcp *SMCPAgent) PostCreate(ctx context.Context, userCred mcclient.TokenCredential, ownerId mcclient.IIdentityProvider, query jsonutils.JSONObject, data jsonutils.JSONObject) {
  151. mcp.SSharableVirtualResourceBase.PostCreate(ctx, userCred, ownerId, query, data)
  152. if mcp.DefaultAgent {
  153. if err := GetMCPAgentManager().unsetOtherDefaultAgents(ctx, mcp.Id); err != nil {
  154. log.Errorf("unsetOtherDefaultAgents after create: %v", err)
  155. }
  156. }
  157. }
  158. func (mcp *SMCPAgent) PostUpdate(ctx context.Context, userCred mcclient.TokenCredential, query jsonutils.JSONObject, data jsonutils.JSONObject) {
  159. mcp.SSharableVirtualResourceBase.PostUpdate(ctx, userCred, query, data)
  160. if mcp.DefaultAgent {
  161. if err := GetMCPAgentManager().unsetOtherDefaultAgents(ctx, mcp.Id); err != nil {
  162. log.Errorf("unsetOtherDefaultAgents after update: %v", err)
  163. }
  164. }
  165. }
  166. func (mcp *SMCPAgent) GetApiKey() (string, error) {
  167. if len(mcp.ApiKey) == 0 {
  168. return "", nil
  169. }
  170. // try decrypt
  171. key, err := seclib.DescryptAESBase64(mcp.Id, mcp.ApiKey)
  172. if err == nil {
  173. return key, nil
  174. }
  175. return mcp.ApiKey, nil
  176. }
  177. func (man *SMCPAgentManager) CustomizeHandlerInfo(info *appsrv.SHandlerInfo) {
  178. man.SSharableVirtualResourceBaseManager.CustomizeHandlerInfo(info)
  179. // log.Infoln("query name of handler info", info.GetName(nil))
  180. switch info.GetName(nil) {
  181. case "get_specific":
  182. info.SetProcessTimeout(time.Hour * 4).SetWorkerManager(mcpAgentWorkerMan)
  183. }
  184. }
  185. func (man *SMCPAgentManager) ValidateCreateData(ctx context.Context, userCred mcclient.TokenCredential, ownerId mcclient.IIdentityProvider, query jsonutils.JSONObject, input *api.MCPAgentCreateInput) (*api.MCPAgentCreateInput, error) {
  186. var err error
  187. input.SharableVirtualResourceCreateInput, err = man.SSharableVirtualResourceBaseManager.ValidateCreateData(ctx, userCred, ownerId, query, input.SharableVirtualResourceCreateInput)
  188. if err != nil {
  189. return input, errors.Wrap(err, "validate SharableVirtualResourceCreateInput")
  190. }
  191. // 如果提供了 llm_id,则通过 LLM 获取 llm_url 和 model
  192. if len(input.LLMId) > 0 {
  193. llmObj, err := GetLLMManager().FetchByIdOrName(ctx, userCred, input.LLMId)
  194. if err != nil {
  195. return input, errors.Wrapf(err, "fetch LLM by id %s", input.LLMId)
  196. }
  197. llm := llmObj.(*SLLM)
  198. input.LLMId = llm.Id
  199. llmUrl, err := llm.GetLLMAccessUrlInfo(ctx, userCred, query)
  200. if err != nil {
  201. return input, errors.Wrapf(err, "get LLM URL from LLM %s", input.LLMId)
  202. }
  203. input.LLMUrl = llmUrl.LoginUrl
  204. if len(input.Model) == 0 {
  205. mdlInfos, err := llm.getProbedInstantModelsExt(ctx, userCred)
  206. if err != nil {
  207. return input, errors.Wrap(err, "get probed models from LLM instance")
  208. }
  209. if len(mdlInfos) == 0 {
  210. return input, httperrors.NewBadRequestError("no available models found in LLM instance %s", input.LLMId)
  211. }
  212. var firstModel api.LLMInternalInstantMdlInfo
  213. for _, mdlInfo := range mdlInfos {
  214. firstModel = mdlInfo
  215. break
  216. }
  217. input.Model = fmt.Sprintf("%s:%s", firstModel.Name, firstModel.Tag)
  218. }
  219. }
  220. // 验证 llm_url 不为空
  221. if len(input.LLMUrl) == 0 {
  222. return input, errors.Wrap(httperrors.ErrInputParameter, "llm_url is required (or provide llm_id to auto-fetch)")
  223. }
  224. // 验证 llm_driver 必须是 ollama 或 openai
  225. input.LLMDriver = strings.ToLower(strings.TrimSpace(input.LLMDriver))
  226. if !api.IsLLMClientType(input.LLMDriver) {
  227. return input, errors.Wrapf(httperrors.ErrInputParameter, "llm_driver must be one of: %s, got: %s", api.LLM_CLIENT_TYPES.List(), input.LLMDriver)
  228. }
  229. // 验证 model 不为空
  230. if len(input.Model) == 0 {
  231. return input, errors.Wrap(httperrors.ErrInputParameter, "model is required")
  232. }
  233. // 验证 mcp_server 不为空
  234. if len(input.McpServer) == 0 {
  235. input.McpServer = options.Options.MCPServerURL
  236. }
  237. // 对于 openai 驱动,api_key 是必需的
  238. if input.LLMDriver == string(api.LLM_CLIENT_OPENAI) && len(input.ApiKey) == 0 {
  239. return input, errors.Wrap(httperrors.ErrInputParameter, "api_key is required when llm_driver is openai")
  240. }
  241. input.Status = api.STATUS_READY
  242. return input, nil
  243. }
  244. func (man *SMCPAgentManager) ValidateUpdateData(ctx context.Context, userCred mcclient.TokenCredential, ownerId mcclient.IIdentityProvider, query jsonutils.JSONObject, input *api.MCPAgentUpdateInput) (*api.MCPAgentUpdateInput, error) {
  245. var err error
  246. input.SharableVirtualResourceCreateInput, err = man.SSharableVirtualResourceBaseManager.ValidateCreateData(ctx, userCred, ownerId, query, input.SharableVirtualResourceCreateInput)
  247. if err != nil {
  248. return input, errors.Wrap(err, "validate SharableVirtualResourceCreateInput")
  249. }
  250. // 如果提供了 llm_id,则通过 LLM 获取 llm_url 和 model
  251. if input.LLMId != nil && len(*input.LLMId) > 0 {
  252. llmObj, err := GetLLMManager().FetchByIdOrName(ctx, userCred, *input.LLMId)
  253. if err != nil {
  254. return input, errors.Wrapf(err, "fetch LLM by id %s", *input.LLMId)
  255. }
  256. llm := llmObj.(*SLLM)
  257. llmUrl, err := llm.GetLLMAccessUrlInfo(ctx, userCred, query)
  258. if err != nil {
  259. return input, errors.Wrapf(err, "get LLM URL from LLM %s", *input.LLMId)
  260. }
  261. input.LLMUrl = &llmUrl.LoginUrl
  262. if input.Model == nil || len(*input.Model) == 0 {
  263. mdlInfos, err := llm.getProbedInstantModelsExt(ctx, userCred)
  264. if err != nil {
  265. return input, errors.Wrap(err, "get probed models from LLM instance")
  266. }
  267. if len(mdlInfos) == 0 {
  268. return input, httperrors.NewBadRequestError("no available models found in LLM instance %s", *input.LLMId)
  269. }
  270. var firstModel api.LLMInternalInstantMdlInfo
  271. for _, mdlInfo := range mdlInfos {
  272. firstModel = mdlInfo
  273. break
  274. }
  275. modelStr := fmt.Sprintf("%s:%s", firstModel.Name, firstModel.Tag)
  276. input.Model = &modelStr
  277. }
  278. }
  279. // 如果更新 llm_driver,验证其值
  280. if input.LLMDriver != nil {
  281. *input.LLMDriver = strings.ToLower(strings.TrimSpace(*input.LLMDriver))
  282. if !api.IsLLMClientType(*input.LLMDriver) {
  283. return input, errors.Wrapf(httperrors.ErrInputParameter, "llm_driver must be one of: %s, got: %s", api.LLM_CLIENT_TYPES.List(), *input.LLMDriver)
  284. }
  285. }
  286. return input, nil
  287. }
  288. func (man *SMCPAgentManager) ListItemFilter(
  289. ctx context.Context,
  290. q *sqlchemy.SQuery,
  291. userCred mcclient.TokenCredential,
  292. input api.MCPAgentListInput,
  293. ) (*sqlchemy.SQuery, error) {
  294. q, err := man.SSharableVirtualResourceBaseManager.ListItemFilter(ctx, q, userCred, input.SharableVirtualResourceListInput)
  295. if err != nil {
  296. return nil, errors.Wrapf(err, "SSharableVirtualResourceBaseManager.ListItemFilter")
  297. }
  298. if len(input.LLMDriver) > 0 {
  299. q = q.Equals("llm_driver", strings.ToLower(strings.TrimSpace(input.LLMDriver)))
  300. }
  301. if input.DefaultAgent != nil && *input.DefaultAgent {
  302. q = q.IsTrue("default_agent")
  303. }
  304. return q, nil
  305. }
  306. func (manager *SMCPAgentManager) FetchCustomizeColumns(
  307. ctx context.Context,
  308. userCred mcclient.TokenCredential,
  309. query jsonutils.JSONObject,
  310. objs []interface{},
  311. fields stringutils2.SSortedStrings,
  312. isList bool,
  313. ) []api.MCPAgentDetails {
  314. rows := make([]api.MCPAgentDetails, len(objs))
  315. vrows := manager.SSharableVirtualResourceBaseManager.FetchCustomizeColumns(ctx, userCred, query, objs, fields, isList)
  316. agents := []SMCPAgent{}
  317. jsonutils.Update(&agents, objs)
  318. llmIds := make([]string, 0)
  319. for i := range agents {
  320. if len(agents[i].LLMId) > 0 {
  321. llmIds = append(llmIds, agents[i].LLMId)
  322. }
  323. }
  324. var llmIdNameMap map[string]string
  325. if len(llmIds) > 0 {
  326. var err error
  327. llmIdNameMap, err = db.FetchIdNameMap2(GetLLMManager(), llmIds)
  328. if err != nil {
  329. log.Errorf("FetchIdNameMap2 for LLMs failed: %v", err)
  330. }
  331. }
  332. for i := range rows {
  333. rows[i].SharableVirtualResourceDetails = vrows[i]
  334. if i < len(agents) {
  335. rows[i].LLMId = agents[i].LLMId
  336. if name, ok := llmIdNameMap[agents[i].LLMId]; ok {
  337. rows[i].LLMName = name
  338. }
  339. rows[i].DefaultAgent = agents[i].DefaultAgent
  340. }
  341. }
  342. return rows
  343. }
  344. func (mcp *SMCPAgent) GetLLMClientDriver() ILLMClient {
  345. return GetLLMClientDriver(api.LLMClientType(mcp.LLMDriver))
  346. }
  347. func (mcp *SMCPAgent) GetMcpServerUrl(ctx context.Context, userCred mcclient.TokenCredential) (string, error) {
  348. if len(mcp.McpServer) > 0 {
  349. return mcp.McpServer, nil
  350. }
  351. return options.Options.MCPServerURL, nil
  352. }
  353. func (mcp *SMCPAgent) GetDetailsMcpTools(ctx context.Context, userCred mcclient.TokenCredential, query jsonutils.JSONObject) (jsonutils.JSONObject, error) {
  354. // 创建 MCP 客户端
  355. timeout := time.Duration(options.Options.MCPAgentTimeout) * time.Second
  356. mcpServerUrl, err := mcp.GetMcpServerUrl(ctx, userCred)
  357. if err != nil {
  358. return nil, errors.Wrap(err, "GetMcpServerUrl")
  359. }
  360. mcpClient := utils.NewMCPClient(mcpServerUrl, timeout, userCred)
  361. // 获取工具列表
  362. tools, err := mcpClient.ListTools(ctx)
  363. if err != nil {
  364. return nil, errors.Wrap(err, "list MCP tools")
  365. }
  366. return jsonutils.Marshal(tools), nil
  367. }
  368. func (mcp *SMCPAgent) GetDetailsToolRequest(
  369. ctx context.Context,
  370. userCred mcclient.TokenCredential,
  371. input api.LLMToolRequestInput,
  372. ) (jsonutils.JSONObject, error) {
  373. // 创建 MCP 客户端
  374. timeout := time.Duration(options.Options.MCPAgentTimeout) * time.Second
  375. mcpServerUrl, err := mcp.GetMcpServerUrl(ctx, userCred)
  376. if err != nil {
  377. return nil, errors.Wrap(err, "GetMcpServerUrl")
  378. }
  379. mcpClient := utils.NewMCPClient(mcpServerUrl, timeout, userCred)
  380. defer mcpClient.Close()
  381. // 调用工具
  382. result, err := mcpClient.CallTool(ctx, input.ToolName, input.Arguments)
  383. if err != nil {
  384. return nil, errors.Wrapf(err, "call tool %s", input.ToolName)
  385. }
  386. return jsonutils.Marshal(result), nil
  387. }
  388. // func (mcp *SMCPAgent) GetDetailsChatTest(
  389. // ctx context.Context,
  390. // userCred mcclient.TokenCredential,
  391. // input api.LLMChatTestInput,
  392. // ) (jsonutils.JSONObject, error) {
  393. // llmClient := mcp.GetLLMClientDriver()
  394. // if llmClient == nil {
  395. // return nil, errors.Error("failed to get LLM client driver")
  396. // }
  397. // message := llmClient.NewUserMessage(input.Message)
  398. // result, err := llmClient.Chat(ctx, mcp, []ILLMChatMessage{message}, nil)
  399. // if err != nil {
  400. // return nil, errors.Wrap(err, "chat with LLM")
  401. // }
  402. // return jsonutils.Marshal(result), nil
  403. // }
  404. func (mcp *SMCPAgent) PerformChatStream(
  405. ctx context.Context,
  406. userCred mcclient.TokenCredential,
  407. query jsonutils.JSONObject,
  408. input api.LLMMCPAgentRequestInput,
  409. ) (jsonutils.JSONObject, error) {
  410. appParams := appsrv.AppContextGetParams(ctx)
  411. if appParams == nil {
  412. return nil, errors.Error("failed to get app params")
  413. }
  414. w := appParams.Response
  415. w.Header().Set("Content-Type", "text/event-stream")
  416. w.Header().Set("Cache-Control", "no-cache")
  417. w.Header().Set("Connection", "keep-alive")
  418. if f, ok := w.(http.Flusher); ok {
  419. f.Flush()
  420. } else {
  421. return nil, errors.Error("Streaming unsupported!")
  422. }
  423. _, err := mcp.process(ctx, userCred, &input, func(content string) error {
  424. if len(content) > 0 {
  425. for line := range strings.SplitSeq(content, "\n") {
  426. fmt.Fprintf(w, "data: %s\n", line)
  427. }
  428. fmt.Fprintf(w, "\n")
  429. if f, ok := w.(http.Flusher); ok {
  430. f.Flush()
  431. }
  432. }
  433. return nil
  434. })
  435. if err != nil {
  436. fmt.Fprintf(w, "data: Error: %v\n\n", err)
  437. }
  438. return nil, nil
  439. }
  440. // process 处理用户请求
  441. func (mcp *SMCPAgent) process(ctx context.Context, userCred mcclient.TokenCredential, req *api.LLMMCPAgentRequestInput, onStream func(string) error) (*api.MCPAgentResponse, error) {
  442. // 获取 MCP Server 的工具列表
  443. mcpServerUrl, err := mcp.GetMcpServerUrl(ctx, userCred)
  444. if err != nil {
  445. return nil, errors.Wrap(err, "GetMcpServerUrl")
  446. }
  447. mcpClient := utils.NewMCPClient(mcpServerUrl, 10*time.Minute, userCred)
  448. defer mcpClient.Close()
  449. mcpTools, err := mcpClient.ListTools(ctx)
  450. if err != nil {
  451. return nil, errors.Wrap(err, "list MCP tools")
  452. }
  453. log.Infof("Got %d tools from MCP Server", len(mcpTools))
  454. // get llmClient
  455. llmClient := mcp.GetLLMClientDriver()
  456. if llmClient == nil {
  457. return nil, errors.Error("failed to get LLM client driver")
  458. }
  459. tools := llmClient.ConvertMCPTools(mcpTools)
  460. // 构建系统提示词
  461. systemPrompt := buildSystemPrompt()
  462. // 初始化消息历史
  463. messages := make([]ILLMChatMessage, 0)
  464. messages = append(messages, llmClient.NewSystemMessage(systemPrompt))
  465. // 处理历史消息
  466. if len(req.History) > 0 {
  467. historyMessages := processHistoryMessages(
  468. req.History,
  469. llmClient,
  470. options.Options.MCPAgentUserCharLimit,
  471. options.Options.MCPAgentAssistantCharLimit,
  472. )
  473. messages = append(messages, historyMessages...)
  474. }
  475. messages = append(messages, llmClient.NewUserMessage(req.Message))
  476. // 记录工具调用
  477. var toolCallRecords []api.MCPAgentToolCallRecord
  478. log.Infof("Phase 1: Thinking & Acting...")
  479. // 处理流式的工具调用参数
  480. type accumToolCall struct {
  481. Id string
  482. Name string
  483. RawArguments strings.Builder
  484. }
  485. accToolCalls := make(map[int]*accumToolCall)
  486. var accumulatedContent strings.Builder
  487. var accumulatedReasoning strings.Builder
  488. hasToolCalls := false
  489. err = llmClient.ChatStream(ctx, mcp, messages, tools, func(chunk ILLMChatResponse) error {
  490. if chunk.HasToolCalls() {
  491. hasToolCalls = true
  492. for _, tc := range chunk.GetToolCalls() {
  493. idx := tc.GetIndex()
  494. if _, exists := accToolCalls[idx]; !exists {
  495. accToolCalls[idx] = &accumToolCall{
  496. Id: tc.GetId(),
  497. }
  498. }
  499. atc := accToolCalls[idx]
  500. if id := tc.GetId(); id != "" {
  501. atc.Id = id
  502. }
  503. if name := tc.GetFunction().GetName(); name != "" {
  504. atc.Name = name
  505. }
  506. if args := tc.GetFunction().GetRawArguments(); args != "" {
  507. atc.RawArguments.WriteString(args)
  508. }
  509. }
  510. }
  511. if r := chunk.GetReasoningContent(); len(r) > 0 {
  512. accumulatedReasoning.WriteString(r)
  513. }
  514. content := chunk.GetContent()
  515. if len(content) > 0 {
  516. accumulatedContent.WriteString(content)
  517. if onStream != nil {
  518. if err := onStream(content); err != nil {
  519. return err
  520. }
  521. }
  522. }
  523. return nil
  524. })
  525. if err != nil {
  526. return nil, errors.Wrap(err, "phase 1 chat stream error")
  527. }
  528. // 检查是否有工具调用
  529. if !hasToolCalls {
  530. // 如果阶段一没有调用工具,直接返回结果
  531. return &api.MCPAgentResponse{
  532. Success: true,
  533. Answer: accumulatedContent.String(),
  534. ToolCalls: toolCallRecords,
  535. }, nil
  536. }
  537. // Convert accumulated tool calls to ILLMToolCall
  538. var toolCalls []ILLMToolCall
  539. // Find max index
  540. maxIdx := -1
  541. for idx := range accToolCalls {
  542. if idx > maxIdx {
  543. maxIdx = idx
  544. }
  545. }
  546. for i := 0; i <= maxIdx; i++ {
  547. if atc, ok := accToolCalls[i]; ok {
  548. var args map[string]interface{}
  549. rawArgs := atc.RawArguments.String()
  550. if len(rawArgs) > 0 {
  551. if err := json.Unmarshal([]byte(rawArgs), &args); err != nil {
  552. log.Errorf("Failed to unmarshal arguments for tool %s: %v. Raw: %s", atc.Name, err, rawArgs)
  553. args = make(map[string]interface{})
  554. }
  555. } else {
  556. args = make(map[string]interface{})
  557. }
  558. toolCalls = append(toolCalls, &SLLMToolCall{
  559. Id: atc.Id,
  560. Function: SLLMFunctionCall{
  561. Name: atc.Name,
  562. Arguments: args,
  563. },
  564. })
  565. }
  566. }
  567. log.Infof("Got %d tool calls from Phase 1", len(toolCalls))
  568. toolCallRecords, toolMessages, err := processToolCalls(ctx, toolCalls, accumulatedReasoning.String(), accumulatedContent.String(), mcpClient, llmClient)
  569. if err != nil {
  570. return nil, errors.Wrap(err, "process tool calls")
  571. }
  572. // 将工具调用相关的消息加入历史
  573. messages = append(messages, toolMessages...)
  574. log.Infof("Phase 2: Streaming Response...")
  575. var finalAnswer strings.Builder
  576. err = llmClient.ChatStream(ctx, mcp, messages, tools, func(chunk ILLMChatResponse) error {
  577. content := chunk.GetContent()
  578. if len(content) > 0 {
  579. // 聚合最终答案
  580. finalAnswer.WriteString(content)
  581. // 实时流式输出
  582. if onStream != nil {
  583. if err := onStream(content); err != nil {
  584. return err
  585. }
  586. }
  587. }
  588. return nil
  589. })
  590. if err != nil {
  591. return nil, errors.Wrap(err, "phase 2 stream error")
  592. }
  593. return &api.MCPAgentResponse{
  594. Success: true,
  595. Answer: finalAnswer.String(),
  596. ToolCalls: toolCallRecords,
  597. }, nil
  598. }
  599. // buildSystemPrompt 构建系统提示词
  600. func buildSystemPrompt() string {
  601. return api.MCP_AGENT_SYSTEM_PROMPT
  602. }
  603. func processHistoryMessages(
  604. history []api.MCPAgentChatMessage,
  605. llmClient ILLMClient,
  606. maxUserChars int,
  607. maxAssistantChars int,
  608. ) []ILLMChatMessage {
  609. if len(history) == 0 {
  610. return []ILLMChatMessage{}
  611. }
  612. var userChars, assistantChars int
  613. processedMessages := make([]ILLMChatMessage, 0)
  614. // 从最新的消息开始遍历,保留最新消息,丢弃最旧消息
  615. for i := len(history) - 1; i >= 0; i-- {
  616. msg := history[i]
  617. msgChars := len(msg.Content)
  618. switch msg.Role {
  619. case "user":
  620. if userChars+msgChars > maxUserChars {
  621. break
  622. }
  623. userChars += msgChars
  624. processedMessages = append(processedMessages, llmClient.NewUserMessage(msg.Content))
  625. case "assistant":
  626. if assistantChars+msgChars > maxAssistantChars {
  627. break
  628. }
  629. assistantChars += msgChars
  630. if len(msg.Content) > 0 {
  631. processedMessages = append(processedMessages, llmClient.NewAssistantMessage(msg.Content))
  632. }
  633. }
  634. }
  635. for i, j := 0, len(processedMessages)-1; i < j; i, j = i+1, j-1 {
  636. processedMessages[i], processedMessages[j] = processedMessages[j], processedMessages[i]
  637. }
  638. return processedMessages
  639. }
  640. // processToolCalls 处理工具调用
  641. func processToolCalls(
  642. ctx context.Context,
  643. toolCalls []ILLMToolCall,
  644. reasoningContent, content string,
  645. mcpClient *utils.MCPClient,
  646. llmClient ILLMClient,
  647. ) ([]api.MCPAgentToolCallRecord, []ILLMChatMessage, error) {
  648. toolCallRecords := make([]api.MCPAgentToolCallRecord, 0)
  649. messagesToAdd := make([]ILLMChatMessage, 0)
  650. // 使用带 reasoning_content 的 assistant 消息,满足 DeepSeek thinking mode + tool calls 要求
  651. messagesToAdd = append(messagesToAdd, llmClient.NewAssistantMessageWithToolCallsAndReasoning(reasoningContent, content, toolCalls))
  652. // 执行每个工具调用
  653. for _, tc := range toolCalls {
  654. fc := tc.GetFunction()
  655. toolName := fc.GetName()
  656. arguments := fc.GetArguments()
  657. if arguments == nil {
  658. arguments = make(map[string]interface{})
  659. }
  660. log.Infof("Calling tool: %s with arguments: %v", toolName, arguments)
  661. // 调用 MCP 工具
  662. result, err := mcpClient.CallTool(ctx, toolName, arguments)
  663. resultText := utils.FormatToolResult(toolName, result, err)
  664. log.Infoln("Get result from mcp query", resultText)
  665. toolCallRecords = append(toolCallRecords, api.MCPAgentToolCallRecord{
  666. Id: tc.GetId(),
  667. ToolName: toolName,
  668. Arguments: arguments,
  669. Result: resultText,
  670. })
  671. // 将工具执行结果加入历史
  672. messagesToAdd = append(messagesToAdd, llmClient.NewToolMessage(tc.GetId(), toolName, resultText))
  673. }
  674. return toolCallRecords, messagesToAdd, nil
  675. }