llm_sku.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package models
  2. import (
  3. "context"
  4. "strings"
  5. "yunion.io/x/jsonutils"
  6. "yunion.io/x/log"
  7. "yunion.io/x/pkg/errors"
  8. "yunion.io/x/sqlchemy"
  9. imageapi "yunion.io/x/onecloud/pkg/apis/image"
  10. api "yunion.io/x/onecloud/pkg/apis/llm"
  11. "yunion.io/x/onecloud/pkg/cloudcommon/db"
  12. "yunion.io/x/onecloud/pkg/httperrors"
  13. "yunion.io/x/onecloud/pkg/mcclient"
  14. "yunion.io/x/onecloud/pkg/mcclient/auth"
  15. imagemodules "yunion.io/x/onecloud/pkg/mcclient/modules/image"
  16. mcclientoptions "yunion.io/x/onecloud/pkg/mcclient/options"
  17. "yunion.io/x/onecloud/pkg/util/stringutils2"
  18. )
  19. func init() {
  20. GetLLMSkuManager()
  21. }
  22. var llmSkuManager *SLLMSkuManager
  23. func GetLLMSkuManager() *SLLMSkuManager {
  24. if llmSkuManager != nil {
  25. return llmSkuManager
  26. }
  27. llmSkuManager = &SLLMSkuManager{
  28. SLLMSkuBaseManager: NewSLLMSkuBaseManager(
  29. SLLMSku{},
  30. "llm_skus_tbl",
  31. "llm_sku",
  32. "llm_skus",
  33. ),
  34. }
  35. llmSkuManager.SetVirtualObject(llmSkuManager)
  36. return llmSkuManager
  37. }
  // SLLMSkuManager is the manager type for SLLMSku records. It combines the
  // base SKU manager with the mounted-models resource manager; both embedded
  // ListItemFilter implementations are chained in SLLMSkuManager.ListItemFilter.
  type SLLMSkuManager struct {
  	SLLMSkuBaseManager
  	SMountedModelsResourceManager
  }
  // SLLMSku is the model for one LLM SKU row. It embeds the shared SKU base
  // fields and the mounted-models resource fields.
  type SLLMSku struct {
  	SLLMSkuBase
  	SMountedModelsResource
  	// primary image id of primary container
  	LLMImageId string `width:"128" charset:"ascii" nullable:"false" list:"user" create:"required" update:"user"`
  	// LLMType is converted to api.LLMContainerType to select the container
  	// driver for this SKU (see GetLLMContainerDriver). It has no update tag.
  	LLMType string `width:"128" charset:"ascii" nullable:"false" list:"user" create:"required"`
  	// LLMSpec is optional on create and may be nil; ValidateUpdateData skips
  	// driver validation when it is nil.
  	LLMSpec *api.LLMSpec `json:"llm_spec" length:"long" list:"user" create:"optional" update:"user"`
  }
  50. func (man *SLLMSkuManager) ListItemFilter(
  51. ctx context.Context,
  52. q *sqlchemy.SQuery,
  53. userCred mcclient.TokenCredential,
  54. input api.LLMSkuListInput,
  55. ) (*sqlchemy.SQuery, error) {
  56. var err error
  57. q, err = man.SLLMSkuBaseManager.ListItemFilter(ctx, q, userCred, input.SharableVirtualResourceListInput)
  58. if err != nil {
  59. return nil, errors.Wrapf(err, "SLLMSkuBaseManager.ListItemFilter")
  60. }
  61. if len(input.LLMType) > 0 {
  62. q = q.Equals("llm_type", input.LLMType)
  63. }
  64. if len(input.LLMTypes) > 0 {
  65. q = q.Filter(sqlchemy.In(q.Field("llm_type"), input.LLMTypes))
  66. }
  67. q, err = man.SMountedModelsResourceManager.ListItemFilter(ctx, q, userCred, input.MountedModelResourceListInput)
  68. if err != nil {
  69. return nil, errors.Wrap(err, "SMountedAppsResourceManager")
  70. }
  71. return q, nil
  72. }
  73. func (manager *SLLMSkuManager) FetchCustomizeColumns(
  74. ctx context.Context,
  75. userCred mcclient.TokenCredential,
  76. query jsonutils.JSONObject,
  77. objs []interface{},
  78. fields stringutils2.SSortedStrings,
  79. isList bool,
  80. ) []api.LLMSkuDetails {
  81. skuIds := []string{}
  82. imageIds := []string{}
  83. templateIds := []string{}
  84. skus := []SLLMSku{}
  85. jsonutils.Update(&skus, objs)
  86. virows := manager.SSharableVirtualResourceBaseManager.FetchCustomizeColumns(ctx, userCred, query, objs, fields, isList)
  87. for _, sku := range skus {
  88. skuIds = append(skuIds, sku.Id)
  89. if imgId := sku.GetLLMImageId(); imgId != "" {
  90. imageIds = append(imageIds, imgId)
  91. }
  92. if sku.Volumes != nil && len(*sku.Volumes) > 0 && len((*sku.Volumes)[0].TemplateId) > 0 {
  93. templateIds = append(templateIds, (*sku.Volumes)[0].TemplateId)
  94. }
  95. }
  96. q := GetLLMManager().Query().In("llm_sku_id", skuIds).GroupBy("llm_sku_id")
  97. q = q.AppendField(q.Field("llm_sku_id"))
  98. q = q.AppendField(sqlchemy.COUNT("llm_capacity"))
  99. details := []struct {
  100. LLMSkuId string
  101. LLMCapacity int
  102. }{}
  103. q.All(&details)
  104. res := make([]api.LLMSkuDetails, len(objs))
  105. mountedModelIds := make([]string, 0)
  106. for i, sku := range skus {
  107. res[i].SharableVirtualResourceDetails = virows[i]
  108. res[i].LLMType = sku.LLMType
  109. res[i].LLMSpec = sku.LLMSpec
  110. for _, v := range details {
  111. if v.LLMSkuId == sku.Id {
  112. res[i].LLMCapacity = v.LLMCapacity
  113. break
  114. }
  115. }
  116. if modelIds := sku.GetMountedModels(); len(modelIds) > 0 {
  117. mountedModelIds = append(mountedModelIds, modelIds...)
  118. }
  119. }
  120. // fetch mounted models
  121. if len(mountedModelIds) > 0 {
  122. instModels := make(map[string]SInstantModel)
  123. err := db.FetchModelObjectsByIds(GetInstantModelManager(), "id", mountedModelIds, &instModels)
  124. if err != nil {
  125. log.Errorf("FetchModelObjectsByIds InstantModelManager fail %s", err)
  126. } else {
  127. for i, sku := range skus {
  128. modelIds := sku.GetMountedModels()
  129. if len(modelIds) > 0 {
  130. res[i].MountedModelDetails = make([]api.MountedModelInfo, 0)
  131. for _, modelId := range modelIds {
  132. if instModel, ok := instModels[modelId]; ok {
  133. info := api.MountedModelInfo{
  134. Id: instModel.Id,
  135. ModelId: instModel.ModelId,
  136. FullName: instModel.ModelName + ":" + instModel.ModelTag,
  137. }
  138. res[i].MountedModelDetails = append(res[i].MountedModelDetails, info)
  139. }
  140. }
  141. }
  142. }
  143. }
  144. }
  145. {
  146. images := make(map[string]SLLMImage)
  147. err := db.FetchModelObjectsByIds(GetLLMImageManager(), "id", imageIds, &images)
  148. if err == nil {
  149. for i, sku := range skus {
  150. if imgId := sku.GetLLMImageId(); imgId != "" {
  151. if image, ok := images[imgId]; ok {
  152. res[i].Image = image.Name
  153. res[i].ImageLabel = image.ImageLabel
  154. res[i].ImageName = image.ImageName
  155. }
  156. }
  157. }
  158. } else {
  159. log.Errorf("FetchModelObjectsByIds LLMImageManager fail %s", err)
  160. }
  161. }
  162. if len(templateIds) > 0 {
  163. templates, err := fetchTemplates(ctx, userCred, templateIds)
  164. if err == nil {
  165. for i, sku := range skus {
  166. if templ, ok := templates[(*sku.Volumes)[0].TemplateId]; ok {
  167. res[i].Template = templ.Name
  168. }
  169. }
  170. } else {
  171. log.Errorf("fail to retrive image info %s", err)
  172. }
  173. }
  174. return res
  175. }
  176. func (man *SLLMSkuManager) ValidateCreateData(ctx context.Context, userCred mcclient.TokenCredential, ownerId mcclient.IIdentityProvider, query jsonutils.JSONObject, input *api.LLMSkuCreateInput) (*api.LLMSkuCreateInput, error) {
  177. var err error
  178. input.LLMSKuBaseCreateInput, err = man.SLLMSkuBaseManager.ValidateCreateData(ctx, userCred, ownerId, query, input.LLMSKuBaseCreateInput)
  179. if err != nil {
  180. return nil, errors.Wrap(err, "SLLMSkuBaseManager.ValidateCreateData")
  181. }
  182. if !api.IsLLMContainerType(input.LLMType) && input.LLMType != string(api.LLM_CONTAINER_DIFY) {
  183. return input, errors.Wrap(httperrors.ErrInputParameter, "llm_type must be one of "+strings.Join(api.LLM_CONTAINER_TYPES.List(), ","))
  184. }
  185. drv, err := GetLLMContainerDriverWithError(api.LLMContainerType(input.LLMType))
  186. if err != nil {
  187. return input, errors.Wrap(err, "get container driver")
  188. }
  189. input, err = drv.ValidateLLMSkuCreateData(ctx, userCred, input)
  190. if err != nil {
  191. return input, errors.Wrap(err, "validate create input")
  192. }
  193. input.Status = api.STATUS_READY
  194. return input, nil
  195. }
  196. // GetLLMImageId returns the primary image id for this SKU. Delegates to driver.
  197. func (sku *SLLMSku) GetLLMImageId() string {
  198. return sku.GetLLMContainerDriver().GetPrimaryImageId(sku)
  199. }
  200. // GetMountedModels returns mounted model ids (from Ollama or Vllm spec). Delegates to instant-model driver; returns nil for drivers that do not support instant models (e.g. Dify).
  201. func (sku *SLLMSku) GetMountedModels() []string {
  202. drv, err := GetLLMContainerInstantModelDriver(api.LLMContainerType(sku.LLMType))
  203. if err != nil {
  204. return nil
  205. }
  206. return drv.GetMountedModels(sku)
  207. }
  208. func (sku *SLLMSku) GetLLMContainerDriver() ILLMContainerDriver {
  209. return GetLLMContainerDriver(api.LLMContainerType(sku.LLMType))
  210. }
  211. func (sku *SLLMSku) ValidateUpdateData(ctx context.Context, userCred mcclient.TokenCredential, query jsonutils.JSONObject, input api.LLMSkuUpdateInput) (api.LLMSkuUpdateInput, error) {
  212. var err error
  213. input.LLMSkuBaseUpdateInput, err = sku.SLLMSkuBase.ValidateUpdateData(ctx, userCred, query, input.LLMSkuBaseUpdateInput)
  214. if err != nil {
  215. return input, errors.Wrap(err, "validate LLMSkuBaseUpdateInput")
  216. }
  217. if sku.LLMSpec == nil {
  218. return input, nil
  219. }
  220. drv := sku.GetLLMContainerDriver()
  221. updateInput, err := drv.ValidateLLMSkuUpdateData(ctx, userCred, sku, &input)
  222. if err != nil {
  223. return input, errors.Wrap(err, "validate update spec")
  224. }
  225. return *updateInput, nil
  226. }
  227. func (sku *SLLMSku) ValidateDeleteCondition(ctx context.Context, info jsonutils.JSONObject) error {
  228. count, err := GetLLMManager().Query().Equals("llm_sku_id", sku.Id).CountWithError()
  229. if err != nil {
  230. return errors.Wrap(err, "fetch llm")
  231. }
  232. if count > 0 {
  233. return errors.Wrap(errors.ErrNotSupported, "This sku is currently in use by LLM")
  234. }
  235. return nil
  236. }
  237. func fetchTemplates(ctx context.Context, userCred mcclient.TokenCredential, templateIds []string) (map[string]imageapi.ImageDetails, error) {
  238. s := auth.GetSession(ctx, userCred, "")
  239. params := mcclientoptions.BaseListOptions{}
  240. params.Id = templateIds
  241. limit := len(templateIds)
  242. params.Limit = &limit
  243. params.Scope = "maxallowed"
  244. results, err := imagemodules.Images.List(s, jsonutils.Marshal(params))
  245. if err != nil {
  246. return nil, errors.Wrap(err, "Images.List")
  247. }
  248. templates := make(map[string]imageapi.ImageDetails)
  249. for i := range results.Data {
  250. tmpl := imageapi.ImageDetails{}
  251. err := results.Data[i].Unmarshal(&tmpl)
  252. if err == nil {
  253. templates[tmpl.Id] = tmpl
  254. }
  255. }
  256. return templates, nil
  257. }