| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package sqlchemy
- import (
- "fmt"
- "reflect"
- "strings"
- "yunion.io/x/jsonutils"
- "yunion.io/x/log"
- "yunion.io/x/pkg/errors"
- "yunion.io/x/pkg/gotypes"
- "yunion.io/x/pkg/util/reflectutils"
- "yunion.io/x/pkg/utils"
- )
- // SUpdateSession is a struct to store the state of a update session
- type SUpdateSession struct {
- oValue reflect.Value
- tableSpec *STableSpec
- }
- func (ts *STableSpec) PrepareUpdate(dt interface{}) (*SUpdateSession, error) {
- if reflect.ValueOf(dt).Kind() != reflect.Ptr {
- return nil, errors.Wrap(ErrNeedsPointer, "Update input must be a Pointer")
- }
- dataValue := reflect.ValueOf(dt).Elem()
- fields := reflectutils.FetchStructFieldValueSet(dataValue) // fetchStructFieldNameValue(dataType, dataValue)
- zeroPrimary := make([]string, 0)
- for _, c := range ts.Columns() {
- k := c.Name()
- ov, ok := fields.GetInterface(k)
- if !ok {
- continue
- }
- if c.IsPrimary() && c.IsZero(ov) && !c.IsText() {
- zeroPrimary = append(zeroPrimary, k)
- }
- }
- if len(zeroPrimary) > 0 {
- return nil, errors.Wrapf(ErrEmptyPrimaryKey, "not a valid data, primary key %s empty",
- strings.Join(zeroPrimary, ","))
- }
- originValue := gotypes.DeepCopyRv(dataValue)
- us := SUpdateSession{oValue: originValue, tableSpec: ts}
- return &us, nil
- }
- // SUpdateDiff is a struct to store the differences for an update of a column
- type SUpdateDiff struct {
- old interface{}
- new interface{}
- col IColumnSpec
- }
- // String of SUpdateDiff returns the string representation of a SUpdateDiff
- func (ud *SUpdateDiff) String() string {
- return fmt.Sprintf("%s->%s",
- utils.TruncateString(ud.old, 32),
- utils.TruncateString(ud.new, 32))
- }
- func (ud SUpdateDiff) jsonObj() jsonutils.JSONObject {
- r := jsonutils.NewDict()
- r.Set("old", jsonutils.Marshal(ud.old))
- r.Set("new", jsonutils.Marshal(ud.new))
- return r
- }
- // UpdateDiffs is a map of SUpdateDiff whose key is the column name
- type UpdateDiffs map[string]SUpdateDiff
- // String of UpdateDiffs returns the string representation of UpdateDiffs
- func (uds UpdateDiffs) String() string {
- obj := jsonutils.NewDict()
- for i := range uds {
- obj.Set(uds[i].col.Name(), uds[i].jsonObj())
- }
- return obj.String()
- }
- func updateDiffList2Map(diffs []SUpdateDiff) UpdateDiffs {
- ret := make(map[string]SUpdateDiff)
- for i := range diffs {
- ret[diffs[i].col.Name()] = diffs[i]
- }
- return ret
- }
// sPrimaryKeyValue pairs a primary key column with the value used to locate
// the row being updated.
type sPrimaryKeyValue struct {
	// key is the name of the primary key column.
	key string
	// value is what the column is compared against in the WHERE condition.
	value interface{}
}
- type SUpdateSQLResult struct {
- Sql string
- Vars []interface{}
- setters []SUpdateDiff
- primaries []sPrimaryKeyValue
- }
- func (us *SUpdateSession) SaveUpdateSql(dt interface{}) (*SUpdateSQLResult, error) {
- beforeUpdateFunc := reflect.ValueOf(dt).MethodByName("BeforeUpdate")
- if beforeUpdateFunc.IsValid() && !beforeUpdateFunc.IsNil() {
- beforeUpdateFunc.Call([]reflect.Value{})
- }
- // dataType := reflect.TypeOf(dt).Elem()
- dataValue := reflect.ValueOf(dt).Elem()
- ofields := reflectutils.FetchStructFieldValueSet(us.oValue)
- fields := reflectutils.FetchStructFieldValueSet(dataValue)
- versionFields := make([]string, 0)
- updatedFields := make([]string, 0)
- primaries := make([]sPrimaryKeyValue, 0)
- setters := make([]SUpdateDiff, 0)
- for _, c := range us.tableSpec.Columns() {
- k := c.Name()
- of, _ := ofields.GetInterface(k)
- nf, _ := fields.GetInterface(k)
- if c.IsPrimary() {
- if !gotypes.IsNil(of) && !c.IsZero(of) {
- if c.IsText() {
- ov, _ := of.(string)
- nv, _ := nf.(string)
- if ov != nv && strings.EqualFold(ov, nv) {
- setters = append(setters, SUpdateDiff{old: of, new: nf, col: c})
- }
- }
- primaries = append(primaries, sPrimaryKeyValue{
- key: k,
- value: c.ConvertFromValue(of),
- })
- } else if c.IsText() {
- primaries = append(primaries, sPrimaryKeyValue{
- key: k,
- value: "",
- })
- } else {
- return nil, ErrEmptyPrimaryKey
- }
- continue
- }
- if c.IsAutoVersion() {
- versionFields = append(versionFields, k)
- continue
- }
- if c.IsUpdatedAt() {
- updatedFields = append(updatedFields, k)
- continue
- }
- if reflect.DeepEqual(of, nf) {
- continue
- }
- if of != nil && nf != nil {
- ofJsonStr := jsonutils.Marshal(of).String()
- nfJsonStr := jsonutils.Marshal(nf).String()
- if ofJsonStr == nfJsonStr {
- continue
- }
- if EqualsGrossValue(of, nf) {
- continue
- }
- }
- if c.IsZero(nf) && c.IsText() {
- nf = nil
- }
- setters = append(setters, SUpdateDiff{old: of, new: nf, col: c})
- }
- if len(setters) == 0 {
- return nil, ErrNoDataToUpdate
- }
- if len(primaries) == 0 {
- return nil, ErrEmptyPrimaryKey
- }
- qChar := us.tableSpec.Database().backend.QuoteChar()
- vars := make([]interface{}, 0)
- colsets := make([]string, 0)
- conditions := make([]string, 0)
- for _, udif := range setters {
- if gotypes.IsNil(udif.new) {
- colsets = append(colsets, fmt.Sprintf("%s%s%s = NULL", qChar, udif.col.Name(), qChar))
- } else {
- // validate text length
- if udif.col.IsString() && udif.col.GetWidth() > 0 {
- newStr, ok := udif.new.(string)
- if ok && len(newStr) > udif.col.GetWidth() {
- udif.new = newStr[:udif.col.GetWidth()]
- }
- }
- colsets = append(colsets, fmt.Sprintf("%s%s%s = ?", qChar, udif.col.Name(), qChar))
- vars = append(vars, udif.col.ConvertFromValue(udif.new))
- }
- }
- for _, versionField := range versionFields {
- colsets = append(colsets, fmt.Sprintf("%s%s%s = %s%s%s + 1", qChar, versionField, qChar, qChar, versionField, qChar))
- }
- for _, updatedField := range updatedFields {
- colsets = append(colsets, fmt.Sprintf("%s%s%s = %s", qChar, updatedField, qChar, us.tableSpec.Database().backend.CurrentUTCTimeStampString()))
- }
- for _, pkv := range primaries {
- conditions = append(conditions, fmt.Sprintf("%s%s%s = ?", qChar, pkv.key, qChar))
- vars = append(vars, pkv.value)
- }
- updateSql := TemplateEval(us.tableSpec.Database().backend.UpdateSQLTemplate(), struct {
- Table string
- Columns string
- Conditions string
- }{
- Table: us.tableSpec.name,
- Columns: strings.Join(colsets, ", "),
- Conditions: strings.Join(conditions, " AND "),
- })
- if DEBUG_SQLCHEMY {
- log.Infof("Update: %s", _sqlDebug(updateSql, vars))
- }
- return &SUpdateSQLResult{
- Sql: updateSql,
- Vars: vars,
- setters: setters,
- primaries: primaries,
- }, nil
- }
- func (us *SUpdateSession) saveUpdate(dt interface{}) (UpdateDiffs, error) {
- sqlResult, err := us.SaveUpdateSql(dt)
- if err != nil {
- return nil, errors.Wrap(err, "saveUpateSql")
- }
- err = us.tableSpec.execUpdateSql(dt, sqlResult)
- if err != nil {
- return nil, errors.Wrap(err, "execUpdateSql")
- }
- return updateDiffList2Map(sqlResult.setters), nil
- }
- func (ts *STableSpec) execUpdateSql(dt interface{}, result *SUpdateSQLResult) error {
- results, err := ts.Database().TxExec(result.Sql, result.Vars...)
- if err != nil {
- return errors.Wrap(err, "TxExec")
- }
- if ts.Database().backend.CanSupportRowAffected() {
- aCnt, err := results.RowsAffected()
- if err != nil {
- return errors.Wrap(err, "results.RowsAffected")
- }
- if aCnt > 1 {
- return errors.Wrapf(ErrUnexpectRowCount, "affected rows %d != 1", aCnt)
- }
- }
- q := ts.Query()
- for _, pkv := range result.primaries {
- q = q.Equals(pkv.key, pkv.value)
- }
- err = q.First(dt)
- if err != nil {
- return errors.Wrapf(err, "query after update failed %s", q.DebugString())
- }
- return nil
- }
- // Update method of STableSpec updates a record of a table,
- // dt is the point to the struct storing the record
- // doUpdate provides method to update the field of the record
- func (ts *STableSpec) Update(dt interface{}, doUpdate func() error) (UpdateDiffs, error) {
- if !ts.Database().backend.CanUpdate() {
- return nil, errors.ErrNotSupported
- }
- session, err := ts.PrepareUpdate(dt)
- if err != nil {
- return nil, errors.Wrap(err, "prepareUpdate")
- }
- err = doUpdate()
- if err != nil {
- return nil, errors.Wrap(err, "")
- }
- uds, err := session.saveUpdate(dt)
- if err != nil && errors.Cause(err) == ErrNoDataToUpdate {
- return nil, nil
- } else if err == nil {
- if DEBUG_SQLCHEMY {
- log.Debugf("Update diff: %s", uds)
- }
- }
- return uds, errors.Wrap(err, "saveUpdate")
- }
|