| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- package llm_container
- import (
- "context"
- "yunion.io/x/pkg/errors"
- computeapi "yunion.io/x/onecloud/pkg/apis/compute"
- api "yunion.io/x/onecloud/pkg/apis/llm"
- "yunion.io/x/onecloud/pkg/cloudcommon/validators"
- "yunion.io/x/onecloud/pkg/httperrors"
- "yunion.io/x/onecloud/pkg/llm/models"
- "yunion.io/x/onecloud/pkg/mcclient"
- )
- type baseDriver struct {
- drvType api.LLMContainerType
- }
- func newBaseDriver(drvType api.LLMContainerType) baseDriver {
- return baseDriver{drvType: drvType}
- }
- func (b *baseDriver) GetType() api.LLMContainerType {
- return b.drvType
- }
- func (b *baseDriver) GetPrimaryImageId(sku *models.SLLMSku) string {
- return sku.LLMImageId
- }
- func (b *baseDriver) GetPrimaryContainer(ctx context.Context, llm *models.SLLM, containers []*computeapi.PodContainerDesc) (*computeapi.PodContainerDesc, error) {
- return containers[0], nil
- }
- func (b *baseDriver) GetMountedModels(sku *models.SLLMSku) []string {
- return sku.MountedModels
- }
- func (b *baseDriver) StartLLM(ctx context.Context, userCred mcclient.TokenCredential, llm *models.SLLM) error {
- return nil
- }
- func (b *baseDriver) ValidateLLMSkuCreateData(ctx context.Context, userCred mcclient.TokenCredential, input *api.LLMSkuCreateInput) (*api.LLMSkuCreateInput, error) {
- imgObj, err := validators.ValidateModel(ctx, userCred, models.GetLLMImageManager(), &input.LLMImageId)
- if err != nil {
- return nil, errors.Wrapf(err, "validate image_id %s", input.LLMImageId)
- }
- llmImage := imgObj.(*models.SLLMImage)
- if llmImage.LLMType != input.LLMType {
- return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "image %s is not of type %s", input.LLMImageId, input.LLMType)
- }
- input.LLMImageId = llmImage.Id
- if input.MountedModels != nil {
- for i, mdl := range input.MountedModels {
- instMdl, err := models.GetInstantModelManager().FetchByIdOrName(ctx, userCred, mdl)
- if err != nil {
- return nil, errors.Wrapf(err, "validate mounted model %s", mdl)
- }
- instantModle := instMdl.(*models.SInstantModel)
- if instantModle.LlmType != input.LLMType {
- return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "mounted model %s is not of type %s", mdl, input.LLMType)
- }
- input.MountedModels[i] = instantModle.GetId()
- }
- }
- return input, nil
- }
- func (b *baseDriver) ValidateLLMSkuUpdateData(ctx context.Context, userCred mcclient.TokenCredential, sku *models.SLLMSku, input *api.LLMSkuUpdateInput) (*api.LLMSkuUpdateInput, error) {
- llmImageId := input.LLMImageId
- if llmImageId != "" {
- imgObj, err := validators.ValidateModel(ctx, userCred, models.GetLLMImageManager(), &llmImageId)
- if err != nil {
- return nil, errors.Wrapf(err, "validate image_id %s", llmImageId)
- }
- llmImage := imgObj.(*models.SLLMImage)
- if llmImage.LLMType != sku.LLMType {
- return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "image %s is not of type %s", llmImageId, sku.LLMType)
- }
- input.LLMImageId = llmImage.Id
- }
- mountedModels := input.MountedModels
- if input.MountedModels != nil {
- mountedModels = make([]string, len(input.MountedModels))
- for i, mdl := range input.MountedModels {
- instMdl, err := models.GetInstantModelManager().FetchByIdOrName(ctx, userCred, mdl)
- if err != nil {
- return nil, errors.Wrapf(err, "validate mounted model %s", mdl)
- }
- instantModle := instMdl.(*models.SInstantModel)
- if instantModle.LlmType != sku.LLMType {
- return nil, errors.Wrapf(httperrors.ErrInvalidStatus, "mounted model %s is not of type %s", mdl, sku.LLMType)
- }
- mountedModels[i] = instantModle.GetId()
- }
- }
- input.MountedModels = mountedModels
- return input, nil
- }
- func MatchContainerToUpdateByName(ctr *computeapi.SContainer, podCtrs []*computeapi.PodContainerCreateInput) (*computeapi.PodContainerCreateInput, error) {
- ctrName := ctr.Name
- for _, podCtr := range podCtrs {
- if podCtr.Name == ctrName {
- return podCtr, nil
- }
- }
- return nil, errors.Wrapf(errors.ErrNotFound, "container %s not found", ctrName)
- }
- func (b *baseDriver) MatchContainerToUpdate(ctr *computeapi.SContainer, podCtrs []*computeapi.PodContainerCreateInput) (*computeapi.PodContainerCreateInput, error) {
- if len(podCtrs) == 1 {
- return podCtrs[0], nil
- }
- return MatchContainerToUpdateByName(ctr, podCtrs)
- }
- func (b *baseDriver) ValidateLLMCreateSpec(ctx context.Context, userCred mcclient.TokenCredential, sku *models.SLLMSku, input *api.LLMSpec) (*api.LLMSpec, error) {
- return input, nil
- }
- func (b *baseDriver) ValidateLLMUpdateSpec(ctx context.Context, userCred mcclient.TokenCredential, llm *models.SLLM, input *api.LLMSpec) (*api.LLMSpec, error) {
- return input, nil
- }
|