ollama.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. package llm_client
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. "github.com/mark3labs/mcp-go/mcp"
  13. "yunion.io/x/pkg/errors"
  14. api "yunion.io/x/onecloud/pkg/apis/llm"
  15. "yunion.io/x/onecloud/pkg/llm/models"
  16. )
// init registers the Ollama client driver with the models registry at
// package load time.
func init() {
	models.RegisterLLMClientDriver(newOllama())
}
// ollama implements the models.ILLMClient interface for Ollama-compatible
// LLM endpoints. It is stateless; all configuration comes in per call.
type ollama struct{}

// newOllama returns a new ollama client driver instance.
func newOllama() models.ILLMClient {
	return new(ollama)
}

// GetType reports the client type identifier for this driver.
func (o *ollama) GetType() api.LLMClientType {
	return api.LLM_CLIENT_OLLAMA
}
  27. func buildOllamaModelsURL(endpoint string) (string, error) {
  28. endpoint = strings.TrimSpace(endpoint)
  29. if endpoint == "" {
  30. return "", errors.Error("endpoint is empty")
  31. }
  32. baseURL, err := url.Parse(endpoint)
  33. if err != nil {
  34. return "", errors.Wrapf(err, "invalid endpoint URL %s", endpoint)
  35. }
  36. baseURL.RawQuery = ""
  37. baseURL.Fragment = ""
  38. path := strings.TrimRight(baseURL.Path, "/")
  39. switch {
  40. case path == "":
  41. baseURL.Path = "/v1/models"
  42. case strings.HasSuffix(path, "/v1/models"):
  43. baseURL.Path = path
  44. case strings.HasSuffix(path, "/v1"):
  45. baseURL.Path = path + "/models"
  46. default:
  47. baseURL.Path = path + "/v1/models"
  48. }
  49. return baseURL.String(), nil
  50. }
  51. func listOllamaModelsWithClient(ctx context.Context, client *http.Client, endpoint string) ([]string, error) {
  52. modelsURL, err := buildOllamaModelsURL(endpoint)
  53. if err != nil {
  54. return nil, err
  55. }
  56. req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
  57. if err != nil {
  58. return nil, errors.Wrap(err, "create request")
  59. }
  60. req.Header.Set("Accept", "application/json")
  61. resp, err := client.Do(req)
  62. if err != nil {
  63. return nil, errors.Wrap(err, "do request")
  64. }
  65. defer resp.Body.Close()
  66. body, err := io.ReadAll(resp.Body)
  67. if err != nil {
  68. return nil, errors.Wrap(err, "read response body")
  69. }
  70. if resp.StatusCode != http.StatusOK {
  71. return nil, errors.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
  72. }
  73. var modelResp OllamaModelsResponse
  74. if err := json.Unmarshal(body, &modelResp); err != nil {
  75. return nil, errors.Wrapf(err, "decode response: %s", string(body))
  76. }
  77. ret := make([]string, 0, len(modelResp.Data))
  78. for _, model := range modelResp.Data {
  79. name := strings.TrimSpace(model.Name)
  80. if name == "" {
  81. name = strings.TrimSpace(model.ID)
  82. }
  83. if name == "" {
  84. continue
  85. }
  86. ret = append(ret, name)
  87. }
  88. return ret, nil
  89. }
  90. func (o *ollama) ListModels(ctx context.Context, endpoint string) ([]string, error) {
  91. client := &http.Client{
  92. Timeout: 30 * time.Second,
  93. Transport: &http.Transport{
  94. MaxIdleConns: 100,
  95. MaxIdleConnsPerHost: 10,
  96. IdleConnTimeout: 90 * time.Second,
  97. },
  98. }
  99. return listOllamaModelsWithClient(ctx, client, endpoint)
  100. }
  101. func convertMessages(messages interface{}) ([]OllamaChatMessage, error) {
  102. // 转换 messages
  103. var ollamaMessages []OllamaChatMessage
  104. if msgs, ok := messages.([]OllamaChatMessage); ok {
  105. ollamaMessages = msgs
  106. } else if msgs, ok := messages.([]models.ILLMChatMessage); ok {
  107. ollamaMessages = make([]OllamaChatMessage, len(msgs))
  108. for i, msg := range msgs {
  109. // 如果 msg 已经是 *OllamaChatMessage,直接解引用使用
  110. if ollamaMsg, ok := msg.(*OllamaChatMessage); ok {
  111. ollamaMessages[i] = *ollamaMsg
  112. } else {
  113. // 否则通过接口方法获取
  114. ollamaMessages[i] = OllamaChatMessage{
  115. Role: msg.GetRole(),
  116. Content: msg.GetContent(),
  117. }
  118. // 转换工具调用
  119. if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 {
  120. ollamaMessages[i].ToolCalls = make([]OllamaToolCall, len(toolCalls))
  121. for j, tc := range toolCalls {
  122. fc := tc.GetFunction()
  123. ollamaMessages[i].ToolCalls[j] = OllamaToolCall{
  124. Function: OllamaFunctionCall{
  125. Name: fc.GetName(),
  126. Arguments: fc.GetArguments(),
  127. },
  128. }
  129. }
  130. }
  131. }
  132. }
  133. } else if msgs, ok := messages.([]interface{}); ok {
  134. ollamaMessages = make([]OllamaChatMessage, 0, len(msgs))
  135. for _, msg := range msgs {
  136. if m, ok := msg.(OllamaChatMessage); ok {
  137. ollamaMessages = append(ollamaMessages, m)
  138. } else if m, ok := msg.(models.ILLMChatMessage); ok {
  139. ollamaMessages = append(ollamaMessages, OllamaChatMessage{
  140. Role: m.GetRole(),
  141. Content: m.GetContent(),
  142. })
  143. }
  144. }
  145. } else {
  146. return nil, errors.Error("invalid messages type, expected []OllamaChatMessage or []ILLMChatMessage")
  147. }
  148. return ollamaMessages, nil
  149. }
  150. func convertTool(tools interface{}) ([]OllamaTool, error) {
  151. // 转换 tools
  152. var ollamaTools []OllamaTool
  153. if ts, ok := tools.([]OllamaTool); ok {
  154. ollamaTools = ts
  155. } else if ts, ok := tools.([]models.ILLMTool); ok {
  156. ollamaTools = make([]OllamaTool, len(ts))
  157. for i, t := range ts {
  158. tf := t.GetFunction()
  159. ollamaTools[i] = OllamaTool{
  160. Type: t.GetType(),
  161. Function: OllamaToolFunction{
  162. Name: tf.GetName(),
  163. Description: tf.GetDescription(),
  164. Parameters: tf.GetParameters(),
  165. },
  166. }
  167. }
  168. } else if ts, ok := tools.([]interface{}); ok && ts != nil {
  169. ollamaTools = make([]OllamaTool, 0, len(ts))
  170. for _, tool := range ts {
  171. if t, ok := tool.(OllamaTool); ok {
  172. ollamaTools = append(ollamaTools, t)
  173. } else if t, ok := tool.(models.ILLMTool); ok {
  174. tf := t.GetFunction()
  175. ollamaTools = append(ollamaTools, OllamaTool{
  176. Type: t.GetType(),
  177. Function: OllamaToolFunction{
  178. Name: tf.GetName(),
  179. Description: tf.GetDescription(),
  180. Parameters: tf.GetParameters(),
  181. },
  182. })
  183. }
  184. }
  185. } else if tools == nil {
  186. ollamaTools = nil
  187. } else {
  188. return nil, errors.Error("invalid tools type, expected []OllamaTool or []ILLMTool or nil")
  189. }
  190. return ollamaTools, nil
  191. }
  192. func initRequestClient(ctx context.Context, endpoint, model string, stream bool, messages []OllamaChatMessage, tools []OllamaTool) (*http.Request, *http.Client, error) {
  193. req := OllamaChatRequest{
  194. Model: model,
  195. Messages: messages,
  196. Tools: tools,
  197. Stream: stream,
  198. }
  199. reqBody, err := json.Marshal(req)
  200. if err != nil {
  201. return nil, nil, errors.Wrap(err, "marshal request")
  202. }
  203. // 规范化 endpoint,确保以 / 结尾
  204. endpoint = strings.TrimSuffix(endpoint, "/")
  205. baseURL, err := url.Parse(endpoint)
  206. if err != nil {
  207. return nil, nil, errors.Wrapf(err, "invalid endpoint URL %s", endpoint)
  208. }
  209. // 构建完整的 URL
  210. apiURL := baseURL.JoinPath("/api/chat")
  211. httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), bytes.NewReader(reqBody))
  212. if err != nil {
  213. return nil, nil, errors.Wrap(err, "create request")
  214. }
  215. httpReq.Header.Set("Content-Type", "application/json")
  216. httpReq.Header.Set("Accept", "application/json")
  217. client := &http.Client{
  218. Timeout: 300 * time.Second,
  219. Transport: &http.Transport{
  220. MaxIdleConns: 100,
  221. MaxIdleConnsPerHost: 10,
  222. IdleConnTimeout: 90 * time.Second,
  223. },
  224. }
  225. return httpReq, client, nil
  226. }
  227. func (o *ollama) Chat(ctx context.Context, mcpAgent *models.SMCPAgent, messages interface{}, tools interface{}) (models.ILLMChatResponse, error) {
  228. ollamaMessages, err := convertMessages(messages)
  229. if err != nil {
  230. return nil, err
  231. }
  232. ollamaTools, err := convertTool(tools)
  233. if err != nil {
  234. return nil, err
  235. }
  236. httpReq, client, err := initRequestClient(ctx, mcpAgent.LLMUrl, mcpAgent.Model, false, ollamaMessages, ollamaTools)
  237. // 调用底层方法
  238. return o.doChatRequest(ctx, httpReq, client)
  239. }
  240. // doChatRequest 执行聊天请求
  241. func (o *ollama) doChatRequest(ctx context.Context, httpReq *http.Request, client *http.Client) (*OllamaChatResponse, error) {
  242. resp, err := client.Do(httpReq)
  243. if err != nil {
  244. return nil, errors.Wrap(err, "do request")
  245. }
  246. defer resp.Body.Close()
  247. // 读取响应体以便错误处理
  248. body, err := io.ReadAll(resp.Body)
  249. if err != nil {
  250. return nil, errors.Wrap(err, "read response body")
  251. }
  252. if resp.StatusCode != http.StatusOK {
  253. return nil, errors.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
  254. }
  255. var chatResp OllamaChatResponse
  256. if err := json.Unmarshal(body, &chatResp); err != nil {
  257. return nil, errors.Wrapf(err, "decode response: %s", string(body))
  258. }
  259. return &chatResp, nil
  260. }
  261. func (o *ollama) NewUserMessage(content string) models.ILLMChatMessage {
  262. return &OllamaChatMessage{
  263. Role: "user",
  264. Content: content,
  265. }
  266. }
  267. func (o *ollama) NewAssistantMessage(content string) models.ILLMChatMessage {
  268. return &OllamaChatMessage{
  269. Role: "assistant",
  270. Content: content,
  271. }
  272. }
  273. func (o *ollama) NewAssistantMessageWithToolCalls(toolCalls []models.ILLMToolCall) models.ILLMChatMessage {
  274. // to ollama tool calls
  275. ollamaToolCalls := make([]OllamaToolCall, len(toolCalls))
  276. for i, tc := range toolCalls {
  277. if otc, ok := tc.(*OllamaToolCall); ok {
  278. ollamaToolCalls[i] = *otc
  279. } else {
  280. fc := tc.GetFunction()
  281. ollamaToolCalls[i] = OllamaToolCall{
  282. Function: OllamaFunctionCall{
  283. Name: fc.GetName(),
  284. Arguments: fc.GetArguments(),
  285. },
  286. }
  287. }
  288. }
  289. return &OllamaChatMessage{
  290. Role: "assistant",
  291. Content: "",
  292. ToolCalls: ollamaToolCalls,
  293. }
  294. }
  295. func (o *ollama) NewAssistantMessageWithToolCallsAndReasoning(reasoningContent, content string, toolCalls []models.ILLMToolCall) models.ILLMChatMessage {
  296. ollamaToolCalls := make([]OllamaToolCall, len(toolCalls))
  297. for i, tc := range toolCalls {
  298. if otc, ok := tc.(*OllamaToolCall); ok {
  299. ollamaToolCalls[i] = *otc
  300. } else {
  301. fc := tc.GetFunction()
  302. ollamaToolCalls[i] = OllamaToolCall{
  303. Function: OllamaFunctionCall{
  304. Name: fc.GetName(),
  305. Arguments: fc.GetArguments(),
  306. },
  307. }
  308. }
  309. }
  310. _ = reasoningContent // Ollama does not use reasoning_content; ignore for compatibility
  311. return &OllamaChatMessage{
  312. Role: "assistant",
  313. Content: content,
  314. ToolCalls: ollamaToolCalls,
  315. }
  316. }
  317. func (o *ollama) NewToolMessage(toolId string, toolName string, content string) models.ILLMChatMessage {
  318. return &OllamaChatMessage{
  319. Role: "tool",
  320. Content: fmt.Sprintf("[%s] %s", toolName, content),
  321. }
  322. }
  323. func (o *ollama) NewSystemMessage(content string) models.ILLMChatMessage {
  324. return &OllamaChatMessage{
  325. Role: "system",
  326. Content: content,
  327. }
  328. }
// ConvertMCPTools converts MCP tool definitions into ILLMTool values in the
// Ollama wire format. Schema decoding is best-effort: on failure the tool is
// still emitted, just with nil Parameters.
func (o *ollama) ConvertMCPTools(mcpTools []mcp.Tool) []models.ILLMTool {
	tools := make([]models.ILLMTool, len(mcpTools))
	for i, t := range mcpTools {
		var params map[string]interface{}
		if t.RawInputSchema != nil {
			// Raw JSON schema provided; decode it directly. The error is
			// deliberately ignored (best-effort) — params stays nil on
			// failure.
			_ = json.Unmarshal(t.RawInputSchema, &params)
		} else {
			// Round-trip the structured schema through JSON to obtain a
			// generic map; errors again leave params nil.
			schemaBytes, _ := json.Marshal(t.InputSchema)
			_ = json.Unmarshal(schemaBytes, &params)
		}
		tools[i] = &OllamaTool{
			Type: "function",
			Function: OllamaToolFunction{
				Name:        t.Name,
				Description: t.Description,
				Parameters:  params,
			},
		}
	}
	return tools
}
// OllamaChatMessage is a single chat message in the Ollama wire format.
// It implements the ILLMChatMessage interface.
type OllamaChatMessage struct {
	Role      string           `json:"role"`
	Content   string           `json:"content"`
	ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
}

// GetRole implements ILLMChatMessage.
func (m OllamaChatMessage) GetRole() string {
	return m.Role
}

// GetContent implements ILLMChatMessage.
func (m OllamaChatMessage) GetContent() string {
	return m.Content
}

// GetToolCalls implements ILLMChatMessage. It returns nil when there are no
// tool calls; otherwise each call is copied before its address is taken, so
// the returned pointers do not alias the message's own slice elements.
func (m OllamaChatMessage) GetToolCalls() []models.ILLMToolCall {
	if len(m.ToolCalls) == 0 {
		return nil
	}
	toolCalls := make([]models.ILLMToolCall, len(m.ToolCalls))
	for i := range m.ToolCalls {
		// Copy to a local before taking the address to avoid aliasing the
		// slice element.
		tc := m.ToolCalls[i]
		toolCalls[i] = &tc
	}
	return toolCalls
}
// OllamaToolCall is a tool invocation requested by the model.
// It implements the ILLMToolCall interface.
type OllamaToolCall struct {
	// Index is assigned locally (see OllamaChatResponse.GetToolCalls) and is
	// never serialized.
	Index    int                `json:"-"`
	Function OllamaFunctionCall `json:"function"`
}

// GetFunction implements ILLMToolCall.
func (tc *OllamaToolCall) GetFunction() models.ILLMFunctionCall {
	return &tc.Function
}

// GetIndex implements ILLMToolCall.
func (tc *OllamaToolCall) GetIndex() int {
	return tc.Index
}

// GetId implements ILLMToolCall. Ollama tool calls carry no id, so this is
// always the empty string.
func (tc *OllamaToolCall) GetId() string {
	return ""
}
// OllamaFunctionCall holds the function name and decoded arguments of a tool
// call. It implements the ILLMFunctionCall interface.
type OllamaFunctionCall struct {
	Name      string                 `json:"name"`
	Arguments map[string]interface{} `json:"arguments"`
}

// GetName implements ILLMFunctionCall.
func (fc *OllamaFunctionCall) GetName() string {
	return fc.Name
}

// GetRawArguments implements ILLMFunctionCall by re-serializing Arguments to
// JSON. Returns "" for nil arguments; a marshal error is silently ignored
// and also yields "".
func (fc *OllamaFunctionCall) GetRawArguments() string {
	if fc.Arguments == nil {
		return ""
	}
	bytes, _ := json.Marshal(fc.Arguments)
	return string(bytes)
}

// GetArguments implements ILLMFunctionCall.
func (fc *OllamaFunctionCall) GetArguments() map[string]interface{} {
	return fc.Arguments
}
// OllamaTool is a tool definition sent to the model.
// It implements the ILLMTool interface.
type OllamaTool struct {
	Type     string             `json:"type"`
	Function OllamaToolFunction `json:"function"`
}

// GetType implements ILLMTool.
func (t OllamaTool) GetType() string {
	return t.Type
}

// GetFunction implements ILLMTool. NOTE(review): with a value receiver the
// returned pointer refers to a copy of the tool's Function, so writes
// through it never reach the original value — fine for the read-only use
// visible here.
func (t OllamaTool) GetFunction() models.ILLMToolFunction {
	return &t.Function
}
// OllamaToolFunction describes a callable function: its name, description,
// and JSON-schema parameters. It implements the ILLMToolFunction interface.
type OllamaToolFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}

// GetName implements ILLMToolFunction.
func (tf *OllamaToolFunction) GetName() string {
	return tf.Name
}

// GetDescription implements ILLMToolFunction.
func (tf *OllamaToolFunction) GetDescription() string {
	return tf.Description
}

// GetParameters implements ILLMToolFunction.
func (tf *OllamaToolFunction) GetParameters() map[string]interface{} {
	return tf.Parameters
}
// OllamaChatRequest is the request body for POST /api/chat.
type OllamaChatRequest struct {
	Model    string              `json:"model"`
	Messages []OllamaChatMessage `json:"messages"`
	Tools    []OllamaTool        `json:"tools,omitempty"`
	// Stream selects streaming (a sequence of JSON chunks) versus a single
	// JSON response.
	Stream bool `json:"stream"`
}

// OllamaModelsResponse is the OpenAI-compatible GET /v1/models response.
type OllamaModelsResponse struct {
	Object string             `json:"object,omitempty"`
	Data   []OllamaModelEntry `json:"data"`
}

// OllamaModelEntry is a single model record from the /v1/models listing.
type OllamaModelEntry struct {
	ID      string `json:"id"`
	Name    string `json:"name,omitempty"`
	Object  string `json:"object,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}
// OllamaChatResponse is one chat response (or, in streaming mode, one chunk)
// from POST /api/chat. It implements the ILLMChatResponse interface.
type OllamaChatResponse struct {
	Model      string            `json:"model"`
	CreatedAt  string            `json:"created_at"`
	Message    OllamaChatMessage `json:"message"`
	Done       bool              `json:"done"`
	DoneReason string            `json:"done_reason,omitempty"`
}

// GetContent returns the assistant message text.
func (r *OllamaChatResponse) GetContent() string {
	return r.Message.Content
}

// GetReasoningContent returns "" — Ollama has no reasoning content field.
func (r *OllamaChatResponse) GetReasoningContent() string {
	return ""
}

// HasToolCalls reports whether the response requests any tool invocations.
func (r *OllamaChatResponse) HasToolCalls() bool {
	return len(r.Message.ToolCalls) > 0
}

// GetToolCalls returns the response's tool calls as ILLMToolCall values, or
// nil when there are none. Side effect: each call's Index field is stamped
// with its slice position, and the returned pointers alias
// r.Message.ToolCalls directly (unlike OllamaChatMessage.GetToolCalls, which
// copies).
func (r *OllamaChatResponse) GetToolCalls() []models.ILLMToolCall {
	if len(r.Message.ToolCalls) == 0 {
		return nil
	}
	toolCalls := make([]models.ILLMToolCall, len(r.Message.ToolCalls))
	for i := range r.Message.ToolCalls {
		r.Message.ToolCalls[i].Index = i
		toolCalls[i] = &r.Message.ToolCalls[i]
	}
	return toolCalls
}
  500. func (o *ollama) ChatStream(ctx context.Context, mcpAgent *models.SMCPAgent, messages interface{}, tools interface{}, onChunk func(models.ILLMChatResponse) error) error {
  501. ollamaMessages, err := convertMessages(messages)
  502. if err != nil {
  503. return err
  504. }
  505. ollamaTools, err := convertTool(tools)
  506. if err != nil {
  507. return err
  508. }
  509. httpReq, client, err := initRequestClient(ctx, mcpAgent.LLMUrl, mcpAgent.Model, true, ollamaMessages, ollamaTools)
  510. return o.doChatStreamRequest(ctx, httpReq, client, onChunk)
  511. }
  512. func (o *ollama) doChatStreamRequest(ctx context.Context, httpReq *http.Request, client *http.Client, onChunk func(models.ILLMChatResponse) error) error {
  513. resp, err := client.Do(httpReq)
  514. if err != nil {
  515. return errors.Wrap(err, "do request")
  516. }
  517. defer resp.Body.Close()
  518. if resp.StatusCode != http.StatusOK {
  519. body, _ := io.ReadAll(resp.Body)
  520. return errors.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
  521. }
  522. decoder := json.NewDecoder(resp.Body)
  523. for {
  524. var chunk OllamaChatResponse
  525. if err := decoder.Decode(&chunk); err != nil {
  526. if err == io.EOF {
  527. break
  528. }
  529. return errors.Wrap(err, "decode stream chunk")
  530. }
  531. if onChunk != nil {
  532. if err := onChunk(&chunk); err != nil {
  533. return errors.Wrap(err, "process chunk")
  534. }
  535. }
  536. if chunk.Done {
  537. break
  538. }
  539. }
  540. return nil
  541. }