base_driver.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package llm_container
  2. import (
  3. "context"
  4. "yunion.io/x/pkg/errors"
  5. computeapi "yunion.io/x/onecloud/pkg/apis/compute"
  6. api "yunion.io/x/onecloud/pkg/apis/llm"
  7. "yunion.io/x/onecloud/pkg/cloudcommon/validators"
  8. "yunion.io/x/onecloud/pkg/httperrors"
  9. "yunion.io/x/onecloud/pkg/llm/models"
  10. "yunion.io/x/onecloud/pkg/mcclient"
  11. )
  12. type baseDriver struct {
  13. drvType api.LLMContainerType
  14. }
  15. func newBaseDriver(drvType api.LLMContainerType) baseDriver {
  16. return baseDriver{drvType: drvType}
  17. }
  18. func (b *baseDriver) GetType() api.LLMContainerType {
  19. return b.drvType
  20. }
  21. func (b *baseDriver) GetPrimaryImageId(sku *models.SLLMSku) string {
  22. return sku.LLMImageId
  23. }
  24. func (b *baseDriver) GetPrimaryContainer(ctx context.Context, llm *models.SLLM, containers []*computeapi.PodContainerDesc) (*computeapi.PodContainerDesc, error) {
  25. return containers[0], nil
  26. }
  27. func (b *baseDriver) GetMountedModels(sku *models.SLLMSku) []string {
  28. return sku.MountedModels
  29. }
  30. func (b *baseDriver) StartLLM(ctx context.Context, userCred mcclient.TokenCredential, llm *models.SLLM) error {
  31. return nil
  32. }
  33. func (b *baseDriver) ValidateLLMSkuCreateData(ctx context.Context, userCred mcclient.TokenCredential, input *api.LLMSkuCreateInput) (*api.LLMSkuCreateInput, error) {
  34. imgObj, err := validators.ValidateModel(ctx, userCred, models.GetLLMImageManager(), &input.LLMImageId)
  35. if err != nil {
  36. return nil, errors.Wrapf(err, "validate image_id %s", input.LLMImageId)
  37. }
  38. llmImage := imgObj.(*models.SLLMImage)
  39. if llmImage.LLMType != input.LLMType {
  40. return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "image %s is not of type %s", input.LLMImageId, input.LLMType)
  41. }
  42. input.LLMImageId = llmImage.Id
  43. if input.MountedModels != nil {
  44. for i, mdl := range input.MountedModels {
  45. instMdl, err := models.GetInstantModelManager().FetchByIdOrName(ctx, userCred, mdl)
  46. if err != nil {
  47. return nil, errors.Wrapf(err, "validate mounted model %s", mdl)
  48. }
  49. instantModle := instMdl.(*models.SInstantModel)
  50. if instantModle.LlmType != input.LLMType {
  51. return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "mounted model %s is not of type %s", mdl, input.LLMType)
  52. }
  53. input.MountedModels[i] = instantModle.GetId()
  54. }
  55. }
  56. return input, nil
  57. }
  58. func (b *baseDriver) ValidateLLMSkuUpdateData(ctx context.Context, userCred mcclient.TokenCredential, sku *models.SLLMSku, input *api.LLMSkuUpdateInput) (*api.LLMSkuUpdateInput, error) {
  59. llmImageId := input.LLMImageId
  60. if llmImageId != "" {
  61. imgObj, err := validators.ValidateModel(ctx, userCred, models.GetLLMImageManager(), &llmImageId)
  62. if err != nil {
  63. return nil, errors.Wrapf(err, "validate image_id %s", llmImageId)
  64. }
  65. llmImage := imgObj.(*models.SLLMImage)
  66. if llmImage.LLMType != sku.LLMType {
  67. return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "image %s is not of type %s", llmImageId, sku.LLMType)
  68. }
  69. input.LLMImageId = llmImage.Id
  70. }
  71. mountedModels := input.MountedModels
  72. if input.MountedModels != nil {
  73. mountedModels = make([]string, len(input.MountedModels))
  74. for i, mdl := range input.MountedModels {
  75. instMdl, err := models.GetInstantModelManager().FetchByIdOrName(ctx, userCred, mdl)
  76. if err != nil {
  77. return nil, errors.Wrapf(err, "validate mounted model %s", mdl)
  78. }
  79. instantModle := instMdl.(*models.SInstantModel)
  80. if instantModle.LlmType != sku.LLMType {
  81. return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "mounted model %s is not of type %s", mdl, sku.LLMType)
  82. }
  83. mountedModels[i] = instantModle.GetId()
  84. }
  85. }
  86. input.MountedModels = mountedModels
  87. return input, nil
  88. }
  89. func MatchContainerToUpdateByName(ctr *computeapi.SContainer, podCtrs []*computeapi.PodContainerCreateInput) (*computeapi.PodContainerCreateInput, error) {
  90. ctrName := ctr.Name
  91. for _, podCtr := range podCtrs {
  92. if podCtr.Name == ctrName {
  93. return podCtr, nil
  94. }
  95. }
  96. return nil, errors.Wrapf(errors.ErrNotFound, "container %s not found", ctrName)
  97. }
  98. func (b *baseDriver) MatchContainerToUpdate(ctr *computeapi.SContainer, podCtrs []*computeapi.PodContainerCreateInput) (*computeapi.PodContainerCreateInput, error) {
  99. if len(podCtrs) == 1 {
  100. return podCtrs[0], nil
  101. }
  102. return MatchContainerToUpdateByName(ctr, podCtrs)
  103. }
  104. func (b *baseDriver) ValidateLLMCreateSpec(ctx context.Context, userCred mcclient.TokenCredential, sku *models.SLLMSku, input *api.LLMSpec) (*api.LLMSpec, error) {
  105. return input, nil
  106. }
  107. func (b *baseDriver) ValidateLLMUpdateSpec(ctx context.Context, userCred mcclient.TokenCredential, llm *models.SLLM, input *api.LLMSpec) (*api.LLMSpec, error) {
  108. return input, nil
  109. }