llm_save_instant_model.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package models
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "os"
  8. "time"
  9. "yunion.io/x/jsonutils"
  10. "yunion.io/x/log"
  11. "yunion.io/x/pkg/errors"
  12. "yunion.io/x/pkg/util/httputils"
  13. computeapi "yunion.io/x/onecloud/pkg/apis/compute"
  14. hostapi "yunion.io/x/onecloud/pkg/apis/host"
  15. api "yunion.io/x/onecloud/pkg/apis/llm"
  16. "yunion.io/x/onecloud/pkg/cloudcommon/db"
  17. "yunion.io/x/onecloud/pkg/cloudcommon/db/taskman"
  18. "yunion.io/x/onecloud/pkg/httperrors"
  19. "yunion.io/x/onecloud/pkg/mcclient"
  20. "yunion.io/x/onecloud/pkg/mcclient/modules/compute"
  21. )
  22. func (llm *SLLM) GetDetailsProbedModels(ctx context.Context, userCred mcclient.TokenCredential, query jsonutils.JSONObject) (jsonutils.JSONObject, error) {
  23. mdlInfos, err := llm.getProbedInstantModelsExt(ctx, userCred)
  24. if err != nil {
  25. return nil, errors.Wrap(err, "getProbedPackagesExt")
  26. }
  27. return jsonutils.Marshal(mdlInfos), nil
  28. }
  29. func (llm *SLLM) PerformSaveInstantModel(
  30. ctx context.Context,
  31. userCred mcclient.TokenCredential,
  32. query jsonutils.JSONObject,
  33. input api.LLMSaveInstantModelInput,
  34. ) (jsonutils.JSONObject, error) {
  35. if llm.Status != api.LLM_STATUS_RUNNING {
  36. return nil, httperrors.NewInvalidStatusError("LLM is not running")
  37. }
  38. mdlInfos, err := llm.getProbedInstantModelsExt(ctx, userCred)
  39. if err != nil {
  40. return nil, errors.Wrap(err, "getProbedPackagesExt")
  41. }
  42. var mdlInfo *api.LLMInternalInstantMdlInfo
  43. for _, info := range mdlInfos {
  44. if info.ModelId == input.ModelId {
  45. mdlInfo = &info
  46. break
  47. }
  48. }
  49. if mdlInfo == nil {
  50. return nil, httperrors.NewBadRequestError("ModelId %s not found", input.ModelId)
  51. }
  52. mountDirs, err := llm.detectModelPaths(ctx, userCred, *mdlInfo)
  53. if err != nil {
  54. return nil, errors.Wrap(err, "detectModelPaths")
  55. }
  56. if len(input.ModelFullName) == 0 {
  57. input.ModelFullName = fmt.Sprintf("%s-%s", mdlInfo.Name+":"+mdlInfo.Tag, time.Now().Format("060102"))
  58. }
  59. var ownerId mcclient.IIdentityProvider
  60. if len(input.TenantId) > 0 {
  61. domainId := input.ProjectDomainId
  62. if len(domainId) == 0 {
  63. domainId = userCred.GetProjectDomainId()
  64. } else {
  65. domain, err := db.TenantCacheManager.FetchDomainByIdOrName(ctx, domainId)
  66. if err != nil {
  67. return nil, errors.Wrap(err, "TenantCache.FetchDomainByIdOrName")
  68. }
  69. domainId = domain.GetId()
  70. }
  71. tenant, err := db.TenantCacheManager.FetchTenantByIdOrNameInDomain(ctx, input.TenantId, domainId)
  72. if err != nil {
  73. return nil, errors.Wrap(err, "TenantCache.FetchById")
  74. }
  75. ownerId = &db.SOwnerId{
  76. DomainId: domainId,
  77. Domain: tenant.Domain,
  78. ProjectId: tenant.Id,
  79. Project: tenant.Name,
  80. }
  81. } else {
  82. ownerId = userCred
  83. }
  84. input.ProjectId = ownerId.GetProjectId()
  85. input.ProjectDomainId = ownerId.GetProjectDomainId()
  86. modelName, modelTag, _ := llm.GetLargeLanguageModelName(input.ModelFullName)
  87. if len(modelName) == 0 {
  88. modelName = mdlInfo.Name
  89. }
  90. if len(modelTag) == 0 {
  91. modelTag = mdlInfo.Tag
  92. }
  93. drv := llm.GetLLMContainerDriver()
  94. instantModelCreateInput := api.InstantModelCreateInput{
  95. LlmType: drv.GetType(),
  96. ModelId: mdlInfo.ModelId,
  97. ModelName: modelName,
  98. ModelTag: modelTag,
  99. Mounts: mountDirs,
  100. }
  101. instantModelCreateInput.Name = input.ModelFullName
  102. boolTrue := true
  103. instantModelCreateInput.DoNotImport = &boolTrue
  104. log.Debugf("instantModelCreateInput: %s", jsonutils.Marshal(instantModelCreateInput))
  105. instantMdlObj, err := db.DoCreate(GetInstantModelManager(), ctx, userCred, nil, jsonutils.Marshal(instantModelCreateInput), ownerId)
  106. if err != nil {
  107. return nil, errors.Wrap(err, "GetInstantModelManager.DoCreate")
  108. }
  109. instantMdl := instantMdlObj.(*SInstantModel)
  110. input.InstantModelId = instantMdl.Id
  111. _, err = llm.StartSaveModelImageTask(ctx, userCred, input)
  112. if err != nil {
  113. return nil, errors.Wrap(err, "StartSaveAppImageTask")
  114. }
  115. return jsonutils.Marshal(instantMdl), nil
  116. }
  117. func (llm *SLLM) DoSaveModelImage(ctx context.Context, userCred mcclient.TokenCredential, session *mcclient.ClientSession, input api.LLMSaveInstantModelInput) error {
  118. llm.SetStatus(ctx, userCred, api.LLM_STATUS_SAVING_MODEL, "DoSaveModelImage")
  119. instantModelObj, err := GetInstantModelManager().FetchById(input.InstantModelId)
  120. if err != nil {
  121. return errors.Wrap(err, "GetInstantModelManager.FetchById")
  122. }
  123. instantModel := instantModelObj.(*SInstantModel)
  124. drv, err := GetLLMContainerInstantModelDriver(llm.GetLLMContainerDriver().GetType())
  125. if err != nil {
  126. return errors.Wrap(err, "GetLLMContainerInstantModelDriver")
  127. }
  128. prefix, saveDirs, err := drv.GetSaveDirectories(instantModel)
  129. if err != nil {
  130. return errors.Wrap(err, "GetSaveDirectories")
  131. }
  132. saveImageInput := computeapi.ContainerSaveVolumeMountToImageInput{
  133. GenerateName: input.ModelFullName,
  134. Notes: fmt.Sprintf("instance model image for %s(%s)", instantModel.ModelId, instantModel.ModelName+":"+instantModel.ModelTag),
  135. Index: 0,
  136. Dirs: saveDirs,
  137. UsedByPostOverlay: true,
  138. DirPrefix: prefix,
  139. }
  140. lc, err := llm.GetLLMContainer()
  141. if err != nil {
  142. return errors.Wrap(err, "GetLLMContainer")
  143. }
  144. result, err := compute.Containers.PerformAction(session, lc.CmpId, "save-volume-mount-image", jsonutils.Marshal(saveImageInput))
  145. if err != nil {
  146. return errors.Wrap(err, "compute.Containers.PerformAction")
  147. }
  148. log.Debugf("container save-volume-mount-image result: %s", result)
  149. saveImageOutput := hostapi.ContainerSaveVolumeMountToImageInput{}
  150. err = result.Unmarshal(&saveImageOutput)
  151. if err != nil {
  152. return errors.Wrap(err, "save-volume-mount-image.result.Unmarshal")
  153. }
  154. err = instantModel.saveImageId(ctx, userCred, saveImageOutput.ImageId)
  155. if err != nil {
  156. return errors.Wrap(err, "saveImageId")
  157. }
  158. return nil
  159. }
  160. func (llm *SLLM) StartSaveModelImageTask(ctx context.Context, userCred mcclient.TokenCredential, input api.LLMSaveInstantModelInput) (*taskman.STask, error) {
  161. llm.SetStatus(ctx, userCred, api.LLM_STATUS_START_SAVE_MODEL, "StartSaveModelImageTask")
  162. params := jsonutils.Marshal(input)
  163. task, err := taskman.TaskManager.NewTask(ctx, "LLMStartSaveModelImageTask", llm, userCred, params.(*jsonutils.JSONDict), "", "")
  164. if err != nil {
  165. return nil, errors.Wrap(err, "taskman.TaskManager.NewTask")
  166. }
  167. err = task.ScheduleRun(nil)
  168. if err != nil {
  169. return nil, errors.Wrap(err, "task.ScheduleRun")
  170. }
  171. return task, nil
  172. }
  173. func (llm *SLLM) detectModelPaths(ctx context.Context, userCred mcclient.TokenCredential, pkgInfo api.LLMInternalInstantMdlInfo) ([]string, error) {
  174. drv, err := GetLLMContainerInstantModelDriver(llm.GetLLMContainerDriver().GetType())
  175. if err != nil {
  176. return nil, errors.Wrap(err, "GetLLMContainerInstantModelDriver")
  177. }
  178. return drv.DetectModelPaths(ctx, userCred, llm, pkgInfo)
  179. }
  180. // HttpGet performs a GET request and returns the response body
  181. func (llm *SLLM) HttpGet(ctx context.Context, url string) ([]byte, error) {
  182. client := httputils.GetTimeoutClient(0)
  183. transport := httputils.GetTransport(true)
  184. client.Transport = transport
  185. resp, err := httputils.Request(client, ctx, httputils.GET, url, http.Header{}, nil, false)
  186. if err != nil {
  187. return nil, errors.Wrap(err, "http request failed")
  188. }
  189. defer resp.Body.Close()
  190. if resp.StatusCode != http.StatusOK {
  191. if resp.StatusCode == http.StatusNotFound {
  192. return nil, httperrors.NewResourceNotFoundError("url %s not found", url)
  193. }
  194. return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode)
  195. }
  196. body, err := io.ReadAll(resp.Body)
  197. if err != nil {
  198. return nil, errors.Wrap(err, "failed to read response body")
  199. }
  200. return body, nil
  201. }
  202. // HttpDownloadFile downloads a file from URL and saves it to the specified path
  203. func (llm *SLLM) HttpDownloadFile(ctx context.Context, url string, filePath string) error {
  204. client := httputils.GetTimeoutClient(0)
  205. transport := httputils.GetTransport(true)
  206. client.Transport = transport
  207. resp, err := httputils.Request(client, ctx, httputils.GET, url, http.Header{}, nil, false)
  208. if err != nil {
  209. return errors.Wrap(err, "http request failed")
  210. }
  211. defer resp.Body.Close()
  212. if resp.StatusCode != http.StatusOK {
  213. if resp.StatusCode == http.StatusNotFound {
  214. return errors.Wrapf(httperrors.ErrResourceNotFound, "url %s not found", url)
  215. }
  216. return errors.Errorf("unexpected status code: %d", resp.StatusCode)
  217. }
  218. // create temporary file first, then rename to avoid partial downloads
  219. tmpPath := filePath + ".tmp"
  220. out, err := os.Create(tmpPath)
  221. if err != nil {
  222. return errors.Wrapf(err, "failed to create file %s", tmpPath)
  223. }
  224. written, err := io.Copy(out, resp.Body)
  225. out.Close()
  226. if err != nil {
  227. os.Remove(tmpPath)
  228. return errors.Wrap(err, "failed to write file")
  229. }
  230. log.Infof("Downloaded %d bytes to %s", written, filePath)
  231. // rename tmp file to final path
  232. if err := os.Rename(tmpPath, filePath); err != nil {
  233. os.Remove(tmpPath)
  234. return errors.Wrapf(err, "failed to rename %s to %s", tmpPath, filePath)
  235. }
  236. return nil
  237. }