| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626 |
- package llm_client
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "strings"
- "time"
- "github.com/mark3labs/mcp-go/mcp"
- "yunion.io/x/pkg/errors"
- api "yunion.io/x/onecloud/pkg/apis/llm"
- "yunion.io/x/onecloud/pkg/llm/models"
- )
// init registers the Ollama driver with the global LLM client driver
// registry so it can be resolved by its client type at runtime.
func init() {
	models.RegisterLLMClientDriver(newOllama())
}
- type ollama struct{}
- func newOllama() models.ILLMClient {
- return new(ollama)
- }
- func (o *ollama) GetType() api.LLMClientType {
- return api.LLM_CLIENT_OLLAMA
- }
- func buildOllamaModelsURL(endpoint string) (string, error) {
- endpoint = strings.TrimSpace(endpoint)
- if endpoint == "" {
- return "", errors.Error("endpoint is empty")
- }
- baseURL, err := url.Parse(endpoint)
- if err != nil {
- return "", errors.Wrapf(err, "invalid endpoint URL %s", endpoint)
- }
- baseURL.RawQuery = ""
- baseURL.Fragment = ""
- path := strings.TrimRight(baseURL.Path, "/")
- switch {
- case path == "":
- baseURL.Path = "/v1/models"
- case strings.HasSuffix(path, "/v1/models"):
- baseURL.Path = path
- case strings.HasSuffix(path, "/v1"):
- baseURL.Path = path + "/models"
- default:
- baseURL.Path = path + "/v1/models"
- }
- return baseURL.String(), nil
- }
- func listOllamaModelsWithClient(ctx context.Context, client *http.Client, endpoint string) ([]string, error) {
- modelsURL, err := buildOllamaModelsURL(endpoint)
- if err != nil {
- return nil, err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
- if err != nil {
- return nil, errors.Wrap(err, "create request")
- }
- req.Header.Set("Accept", "application/json")
- resp, err := client.Do(req)
- if err != nil {
- return nil, errors.Wrap(err, "do request")
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, errors.Wrap(err, "read response body")
- }
- if resp.StatusCode != http.StatusOK {
- return nil, errors.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
- }
- var modelResp OllamaModelsResponse
- if err := json.Unmarshal(body, &modelResp); err != nil {
- return nil, errors.Wrapf(err, "decode response: %s", string(body))
- }
- ret := make([]string, 0, len(modelResp.Data))
- for _, model := range modelResp.Data {
- name := strings.TrimSpace(model.Name)
- if name == "" {
- name = strings.TrimSpace(model.ID)
- }
- if name == "" {
- continue
- }
- ret = append(ret, name)
- }
- return ret, nil
- }
- func (o *ollama) ListModels(ctx context.Context, endpoint string) ([]string, error) {
- client := &http.Client{
- Timeout: 30 * time.Second,
- Transport: &http.Transport{
- MaxIdleConns: 100,
- MaxIdleConnsPerHost: 10,
- IdleConnTimeout: 90 * time.Second,
- },
- }
- return listOllamaModelsWithClient(ctx, client, endpoint)
- }
- func convertMessages(messages interface{}) ([]OllamaChatMessage, error) {
- // 转换 messages
- var ollamaMessages []OllamaChatMessage
- if msgs, ok := messages.([]OllamaChatMessage); ok {
- ollamaMessages = msgs
- } else if msgs, ok := messages.([]models.ILLMChatMessage); ok {
- ollamaMessages = make([]OllamaChatMessage, len(msgs))
- for i, msg := range msgs {
- // 如果 msg 已经是 *OllamaChatMessage,直接解引用使用
- if ollamaMsg, ok := msg.(*OllamaChatMessage); ok {
- ollamaMessages[i] = *ollamaMsg
- } else {
- // 否则通过接口方法获取
- ollamaMessages[i] = OllamaChatMessage{
- Role: msg.GetRole(),
- Content: msg.GetContent(),
- }
- // 转换工具调用
- if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 {
- ollamaMessages[i].ToolCalls = make([]OllamaToolCall, len(toolCalls))
- for j, tc := range toolCalls {
- fc := tc.GetFunction()
- ollamaMessages[i].ToolCalls[j] = OllamaToolCall{
- Function: OllamaFunctionCall{
- Name: fc.GetName(),
- Arguments: fc.GetArguments(),
- },
- }
- }
- }
- }
- }
- } else if msgs, ok := messages.([]interface{}); ok {
- ollamaMessages = make([]OllamaChatMessage, 0, len(msgs))
- for _, msg := range msgs {
- if m, ok := msg.(OllamaChatMessage); ok {
- ollamaMessages = append(ollamaMessages, m)
- } else if m, ok := msg.(models.ILLMChatMessage); ok {
- ollamaMessages = append(ollamaMessages, OllamaChatMessage{
- Role: m.GetRole(),
- Content: m.GetContent(),
- })
- }
- }
- } else {
- return nil, errors.Error("invalid messages type, expected []OllamaChatMessage or []ILLMChatMessage")
- }
- return ollamaMessages, nil
- }
- func convertTool(tools interface{}) ([]OllamaTool, error) {
- // 转换 tools
- var ollamaTools []OllamaTool
- if ts, ok := tools.([]OllamaTool); ok {
- ollamaTools = ts
- } else if ts, ok := tools.([]models.ILLMTool); ok {
- ollamaTools = make([]OllamaTool, len(ts))
- for i, t := range ts {
- tf := t.GetFunction()
- ollamaTools[i] = OllamaTool{
- Type: t.GetType(),
- Function: OllamaToolFunction{
- Name: tf.GetName(),
- Description: tf.GetDescription(),
- Parameters: tf.GetParameters(),
- },
- }
- }
- } else if ts, ok := tools.([]interface{}); ok && ts != nil {
- ollamaTools = make([]OllamaTool, 0, len(ts))
- for _, tool := range ts {
- if t, ok := tool.(OllamaTool); ok {
- ollamaTools = append(ollamaTools, t)
- } else if t, ok := tool.(models.ILLMTool); ok {
- tf := t.GetFunction()
- ollamaTools = append(ollamaTools, OllamaTool{
- Type: t.GetType(),
- Function: OllamaToolFunction{
- Name: tf.GetName(),
- Description: tf.GetDescription(),
- Parameters: tf.GetParameters(),
- },
- })
- }
- }
- } else if tools == nil {
- ollamaTools = nil
- } else {
- return nil, errors.Error("invalid tools type, expected []OllamaTool or []ILLMTool or nil")
- }
- return ollamaTools, nil
- }
- func initRequestClient(ctx context.Context, endpoint, model string, stream bool, messages []OllamaChatMessage, tools []OllamaTool) (*http.Request, *http.Client, error) {
- req := OllamaChatRequest{
- Model: model,
- Messages: messages,
- Tools: tools,
- Stream: stream,
- }
- reqBody, err := json.Marshal(req)
- if err != nil {
- return nil, nil, errors.Wrap(err, "marshal request")
- }
- // 规范化 endpoint,确保以 / 结尾
- endpoint = strings.TrimSuffix(endpoint, "/")
- baseURL, err := url.Parse(endpoint)
- if err != nil {
- return nil, nil, errors.Wrapf(err, "invalid endpoint URL %s", endpoint)
- }
- // 构建完整的 URL
- apiURL := baseURL.JoinPath("/api/chat")
- httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), bytes.NewReader(reqBody))
- if err != nil {
- return nil, nil, errors.Wrap(err, "create request")
- }
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("Accept", "application/json")
- client := &http.Client{
- Timeout: 300 * time.Second,
- Transport: &http.Transport{
- MaxIdleConns: 100,
- MaxIdleConnsPerHost: 10,
- IdleConnTimeout: 90 * time.Second,
- },
- }
- return httpReq, client, nil
- }
- func (o *ollama) Chat(ctx context.Context, mcpAgent *models.SMCPAgent, messages interface{}, tools interface{}) (models.ILLMChatResponse, error) {
- ollamaMessages, err := convertMessages(messages)
- if err != nil {
- return nil, err
- }
- ollamaTools, err := convertTool(tools)
- if err != nil {
- return nil, err
- }
- httpReq, client, err := initRequestClient(ctx, mcpAgent.LLMUrl, mcpAgent.Model, false, ollamaMessages, ollamaTools)
- // 调用底层方法
- return o.doChatRequest(ctx, httpReq, client)
- }
- // doChatRequest 执行聊天请求
- func (o *ollama) doChatRequest(ctx context.Context, httpReq *http.Request, client *http.Client) (*OllamaChatResponse, error) {
- resp, err := client.Do(httpReq)
- if err != nil {
- return nil, errors.Wrap(err, "do request")
- }
- defer resp.Body.Close()
- // 读取响应体以便错误处理
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, errors.Wrap(err, "read response body")
- }
- if resp.StatusCode != http.StatusOK {
- return nil, errors.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
- }
- var chatResp OllamaChatResponse
- if err := json.Unmarshal(body, &chatResp); err != nil {
- return nil, errors.Wrapf(err, "decode response: %s", string(body))
- }
- return &chatResp, nil
- }
- func (o *ollama) NewUserMessage(content string) models.ILLMChatMessage {
- return &OllamaChatMessage{
- Role: "user",
- Content: content,
- }
- }
- func (o *ollama) NewAssistantMessage(content string) models.ILLMChatMessage {
- return &OllamaChatMessage{
- Role: "assistant",
- Content: content,
- }
- }
- func (o *ollama) NewAssistantMessageWithToolCalls(toolCalls []models.ILLMToolCall) models.ILLMChatMessage {
- // to ollama tool calls
- ollamaToolCalls := make([]OllamaToolCall, len(toolCalls))
- for i, tc := range toolCalls {
- if otc, ok := tc.(*OllamaToolCall); ok {
- ollamaToolCalls[i] = *otc
- } else {
- fc := tc.GetFunction()
- ollamaToolCalls[i] = OllamaToolCall{
- Function: OllamaFunctionCall{
- Name: fc.GetName(),
- Arguments: fc.GetArguments(),
- },
- }
- }
- }
- return &OllamaChatMessage{
- Role: "assistant",
- Content: "",
- ToolCalls: ollamaToolCalls,
- }
- }
- func (o *ollama) NewAssistantMessageWithToolCallsAndReasoning(reasoningContent, content string, toolCalls []models.ILLMToolCall) models.ILLMChatMessage {
- ollamaToolCalls := make([]OllamaToolCall, len(toolCalls))
- for i, tc := range toolCalls {
- if otc, ok := tc.(*OllamaToolCall); ok {
- ollamaToolCalls[i] = *otc
- } else {
- fc := tc.GetFunction()
- ollamaToolCalls[i] = OllamaToolCall{
- Function: OllamaFunctionCall{
- Name: fc.GetName(),
- Arguments: fc.GetArguments(),
- },
- }
- }
- }
- _ = reasoningContent // Ollama does not use reasoning_content; ignore for compatibility
- return &OllamaChatMessage{
- Role: "assistant",
- Content: content,
- ToolCalls: ollamaToolCalls,
- }
- }
- func (o *ollama) NewToolMessage(toolId string, toolName string, content string) models.ILLMChatMessage {
- return &OllamaChatMessage{
- Role: "tool",
- Content: fmt.Sprintf("[%s] %s", toolName, content),
- }
- }
- func (o *ollama) NewSystemMessage(content string) models.ILLMChatMessage {
- return &OllamaChatMessage{
- Role: "system",
- Content: content,
- }
- }
- func (o *ollama) ConvertMCPTools(mcpTools []mcp.Tool) []models.ILLMTool {
- tools := make([]models.ILLMTool, len(mcpTools))
- for i, t := range mcpTools {
- var params map[string]interface{}
- if t.RawInputSchema != nil {
- _ = json.Unmarshal(t.RawInputSchema, ¶ms)
- } else {
- schemaBytes, _ := json.Marshal(t.InputSchema)
- _ = json.Unmarshal(schemaBytes, ¶ms)
- }
- tools[i] = &OllamaTool{
- Type: "function",
- Function: OllamaToolFunction{
- Name: t.Name,
- Description: t.Description,
- Parameters: params,
- },
- }
- }
- return tools
- }
// OllamaChatMessage is a single chat message in the Ollama wire format.
// It implements the ILLMChatMessage interface.
type OllamaChatMessage struct {
	Role      string           `json:"role"`
	Content   string           `json:"content"`
	ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
}
// GetRole implements the ILLMChatMessage interface.
func (m OllamaChatMessage) GetRole() string {
	return m.Role
}

// GetContent implements the ILLMChatMessage interface.
func (m OllamaChatMessage) GetContent() string {
	return m.Content
}
- // GetToolCalls 实现 ILLMChatMessage 接口
- func (m OllamaChatMessage) GetToolCalls() []models.ILLMToolCall {
- if len(m.ToolCalls) == 0 {
- return nil
- }
- toolCalls := make([]models.ILLMToolCall, len(m.ToolCalls))
- for i := range m.ToolCalls {
- // 创建副本以避免引用问题
- tc := m.ToolCalls[i]
- toolCalls[i] = &tc
- }
- return toolCalls
- }
// OllamaToolCall is a tool invocation requested by the model.
// It implements the ILLMToolCall interface.
type OllamaToolCall struct {
	// Index is the position of this call within a response; it is filled in
	// by OllamaChatResponse.GetToolCalls and never serialized.
	Index    int                `json:"-"`
	Function OllamaFunctionCall `json:"function"`
}
// GetFunction implements the ILLMToolCall interface.
func (tc *OllamaToolCall) GetFunction() models.ILLMFunctionCall {
	return &tc.Function
}

// GetIndex implements the ILLMToolCall interface.
func (tc *OllamaToolCall) GetIndex() int {
	return tc.Index
}

// GetId implements the ILLMToolCall interface. Ollama tool calls carry no
// call ID, so this always returns the empty string.
func (tc *OllamaToolCall) GetId() string {
	return ""
}
// OllamaFunctionCall holds the function name and decoded arguments of a
// tool call. It implements the ILLMFunctionCall interface.
type OllamaFunctionCall struct {
	Name      string                 `json:"name"`
	Arguments map[string]interface{} `json:"arguments"`
}
// GetName implements the ILLMFunctionCall interface.
func (fc *OllamaFunctionCall) GetName() string {
	return fc.Name
}
- // GetRawArguments 实现 ILLMFunctionCall 接口
- func (fc *OllamaFunctionCall) GetRawArguments() string {
- if fc.Arguments == nil {
- return ""
- }
- bytes, _ := json.Marshal(fc.Arguments)
- return string(bytes)
- }
// GetArguments implements the ILLMFunctionCall interface. It returns the
// decoded argument map, which may be nil.
func (fc *OllamaFunctionCall) GetArguments() map[string]interface{} {
	return fc.Arguments
}
// OllamaTool is a tool definition sent with a chat request.
// It implements the ILLMTool interface.
type OllamaTool struct {
	Type     string             `json:"type"`
	Function OllamaToolFunction `json:"function"`
}
// GetType implements the ILLMTool interface.
func (t OllamaTool) GetType() string {
	return t.Type
}

// GetFunction implements the ILLMTool interface.
// NOTE(review): the value receiver means this returns a pointer into a copy
// of the tool, so mutations made through the returned interface do not
// affect the original OllamaTool value — confirm this is intended.
func (t OllamaTool) GetFunction() models.ILLMToolFunction {
	return &t.Function
}
// OllamaToolFunction describes a callable function exposed to the model.
// It implements the ILLMToolFunction interface.
type OllamaToolFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}
// GetName implements the ILLMToolFunction interface.
func (tf *OllamaToolFunction) GetName() string {
	return tf.Name
}

// GetDescription implements the ILLMToolFunction interface.
func (tf *OllamaToolFunction) GetDescription() string {
	return tf.Description
}

// GetParameters implements the ILLMToolFunction interface.
func (tf *OllamaToolFunction) GetParameters() map[string]interface{} {
	return tf.Parameters
}
// OllamaChatRequest is the request body for Ollama's /api/chat endpoint.
type OllamaChatRequest struct {
	Model    string              `json:"model"`
	Messages []OllamaChatMessage `json:"messages"`
	Tools    []OllamaTool        `json:"tools,omitempty"`
	Stream   bool                `json:"stream"`
}
// OllamaModelsResponse is the OpenAI-compatible /v1/models response envelope.
type OllamaModelsResponse struct {
	Object string             `json:"object,omitempty"`
	Data   []OllamaModelEntry `json:"data"`
}

// OllamaModelEntry is a single model record from the /v1/models listing.
type OllamaModelEntry struct {
	ID      string `json:"id"`
	Name    string `json:"name,omitempty"`
	Object  string `json:"object,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}
// OllamaChatResponse is a chat response — or a single streaming chunk — from
// Ollama's /api/chat endpoint. It implements the ILLMChatResponse interface.
type OllamaChatResponse struct {
	Model      string            `json:"model"`
	CreatedAt  string            `json:"created_at"`
	Message    OllamaChatMessage `json:"message"`
	Done       bool              `json:"done"`
	DoneReason string            `json:"done_reason,omitempty"`
}
// GetContent returns the text content of the response message.
func (r *OllamaChatResponse) GetContent() string {
	return r.Message.Content
}

// GetReasoningContent returns the reasoning content. Ollama does not
// support reasoning output, so this is always empty.
func (r *OllamaChatResponse) GetReasoningContent() string {
	return ""
}

// HasToolCalls reports whether the response message requested tool calls.
func (r *OllamaChatResponse) HasToolCalls() bool {
	return len(r.Message.ToolCalls) > 0
}
// GetToolCalls returns the response's tool calls as interface values.
// It writes each call's slice position into its Index field in place, and
// the returned pointers alias the receiver's stored tool calls — callers
// that mutate them mutate the response.
func (r *OllamaChatResponse) GetToolCalls() []models.ILLMToolCall {
	if len(r.Message.ToolCalls) == 0 {
		return nil
	}
	toolCalls := make([]models.ILLMToolCall, len(r.Message.ToolCalls))
	for i := range r.Message.ToolCalls {
		r.Message.ToolCalls[i].Index = i
		toolCalls[i] = &r.Message.ToolCalls[i]
	}
	return toolCalls
}
- func (o *ollama) ChatStream(ctx context.Context, mcpAgent *models.SMCPAgent, messages interface{}, tools interface{}, onChunk func(models.ILLMChatResponse) error) error {
- ollamaMessages, err := convertMessages(messages)
- if err != nil {
- return err
- }
- ollamaTools, err := convertTool(tools)
- if err != nil {
- return err
- }
- httpReq, client, err := initRequestClient(ctx, mcpAgent.LLMUrl, mcpAgent.Model, true, ollamaMessages, ollamaTools)
- return o.doChatStreamRequest(ctx, httpReq, client, onChunk)
- }
- func (o *ollama) doChatStreamRequest(ctx context.Context, httpReq *http.Request, client *http.Client, onChunk func(models.ILLMChatResponse) error) error {
- resp, err := client.Do(httpReq)
- if err != nil {
- return errors.Wrap(err, "do request")
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- return errors.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
- }
- decoder := json.NewDecoder(resp.Body)
- for {
- var chunk OllamaChatResponse
- if err := decoder.Decode(&chunk); err != nil {
- if err == io.EOF {
- break
- }
- return errors.Wrap(err, "decode stream chunk")
- }
- if onChunk != nil {
- if err := onChunk(&chunk); err != nil {
- return errors.Wrap(err, "process chunk")
- }
- }
- if chunk.Done {
- break
- }
- }
- return nil
- }
|