| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package sqlchemy
- import (
- "bytes"
- "fmt"
- "reflect"
- "yunion.io/x/log"
- "yunion.io/x/pkg/errors"
- "yunion.io/x/pkg/gotypes"
- "yunion.io/x/pkg/util/reflectutils"
- )
- // Increment perform an incremental update on a record, the primary key of the record is specified in diff,
- // the numeric fields of this record will be atomically added by the value of the corresponding field in diff
- // if target is given as a pointer to a variable, the result will be stored in the target
- // if target is not given, the updated result will be stored in diff
- func (t *STableSpec) Increment(diff interface{}, target interface{}) error {
- if !t.Database().backend.CanUpdate() {
- return errors.ErrNotSupported
- }
- return t.incrementInternal(diff, "+", target)
- }
- // Decrement is similar to Increment methods, the difference is that this method will atomically decrease the numeric fields
- // with the value of diff
- func (t *STableSpec) Decrement(diff interface{}, target interface{}) error {
- if !t.Database().backend.CanUpdate() {
- return errors.ErrNotSupported
- }
- return t.incrementInternal(diff, "-", target)
- }
- func (t *STableSpec) incrementInternalSql(diff interface{}, opcode string, target interface{}) (*SUpdateSQLResult, error) {
- dataValue := reflect.Indirect(reflect.ValueOf(diff))
- fields := reflectutils.FetchStructFieldValueSet(dataValue)
- var targetFields reflectutils.SStructFieldValueSet
- if target != nil {
- targetValue := reflect.Indirect(reflect.ValueOf(target))
- targetFields = reflectutils.FetchStructFieldValueSet(targetValue)
- }
- qChar := t.Database().backend.QuoteChar()
- primaries := make([]sPrimaryKeyValue, 0)
- vars := make([]interface{}, 0)
- versionFields := make([]string, 0)
- updatedFields := make([]string, 0)
- incFields := make([]string, 0)
- for _, c := range t.Columns() {
- k := c.Name()
- v, _ := fields.GetInterface(k)
- if c.IsPrimary() {
- if targetFields != nil {
- v, _ = targetFields.GetInterface(k)
- }
- if !gotypes.IsNil(v) && !c.IsZero(v) {
- primaries = append(primaries, sPrimaryKeyValue{
- key: k,
- value: v,
- })
- } else if c.IsText() {
- primaries = append(primaries, sPrimaryKeyValue{
- key: k,
- value: "",
- })
- } else {
- return nil, ErrEmptyPrimaryKey
- }
- continue
- }
- if c.IsUpdatedAt() {
- updatedFields = append(updatedFields, k)
- continue
- }
- if c.IsAutoVersion() {
- versionFields = append(versionFields, k)
- continue
- }
- if c.IsNumeric() && !c.IsZero(v) {
- incFields = append(incFields, k)
- vars = append(vars, v)
- continue
- }
- }
- if len(vars) == 0 {
- return nil, ErrNoDataToUpdate
- }
- if len(primaries) == 0 {
- return nil, ErrEmptyPrimaryKey
- }
- var buf bytes.Buffer
- buf.WriteString(fmt.Sprintf("UPDATE %s%s%s SET ", qChar, t.name, qChar))
- first := true
- for _, k := range incFields {
- if first {
- first = false
- } else {
- buf.WriteString(", ")
- }
- buf.WriteString(fmt.Sprintf("%s%s%s = %s%s%s %s ?", qChar, k, qChar, qChar, k, qChar, opcode))
- }
- for _, versionField := range versionFields {
- buf.WriteString(fmt.Sprintf(", %s%s%s = %s%s%s + 1", qChar, versionField, qChar, qChar, versionField, qChar))
- }
- for _, updatedField := range updatedFields {
- buf.WriteString(fmt.Sprintf(", %s%s%s = %s", qChar, updatedField, qChar, t.Database().backend.CurrentUTCTimeStampString()))
- }
- buf.WriteString(" WHERE ")
- for i, pkv := range primaries {
- if i > 0 {
- buf.WriteString(" AND ")
- }
- buf.WriteString(fmt.Sprintf("%s%s%s = ?", qChar, pkv.key, qChar))
- vars = append(vars, pkv.value)
- }
- if DEBUG_SQLCHEMY {
- log.Infof("Update: %s %s", buf.String(), vars)
- }
- return &SUpdateSQLResult{
- Sql: buf.String(),
- Vars: vars,
- primaries: primaries,
- }, nil
- }
- func (t *STableSpec) incrementInternal(diff interface{}, opcode string, target interface{}) error {
- if target == nil {
- if reflect.ValueOf(diff).Kind() != reflect.Ptr {
- return errors.Wrap(ErrNeedsPointer, "Incremental input must be a Pointer")
- }
- } else {
- if reflect.ValueOf(target).Kind() != reflect.Ptr {
- return errors.Wrap(ErrNeedsPointer, "Incremental update target must be a Pointer")
- }
- }
- intResult, err := t.incrementInternalSql(diff, opcode, target)
- if target != nil {
- err = t.execUpdateSql(target, intResult)
- } else {
- err = t.execUpdateSql(diff, intResult)
- }
- if err != nil {
- return errors.Wrap(err, "query after update failed")
- }
- return nil
- }
|