inc.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package sqlchemy
  15. import (
  16. "bytes"
  17. "fmt"
  18. "reflect"
  19. "yunion.io/x/log"
  20. "yunion.io/x/pkg/errors"
  21. "yunion.io/x/pkg/gotypes"
  22. "yunion.io/x/pkg/util/reflectutils"
  23. )
  24. // Increment perform an incremental update on a record, the primary key of the record is specified in diff,
  25. // the numeric fields of this record will be atomically added by the value of the corresponding field in diff
  26. // if target is given as a pointer to a variable, the result will be stored in the target
  27. // if target is not given, the updated result will be stored in diff
  28. func (t *STableSpec) Increment(diff interface{}, target interface{}) error {
  29. if !t.Database().backend.CanUpdate() {
  30. return errors.ErrNotSupported
  31. }
  32. return t.incrementInternal(diff, "+", target)
  33. }
  34. // Decrement is similar to Increment methods, the difference is that this method will atomically decrease the numeric fields
  35. // with the value of diff
  36. func (t *STableSpec) Decrement(diff interface{}, target interface{}) error {
  37. if !t.Database().backend.CanUpdate() {
  38. return errors.ErrNotSupported
  39. }
  40. return t.incrementInternal(diff, "-", target)
  41. }
  42. func (t *STableSpec) incrementInternalSql(diff interface{}, opcode string, target interface{}) (*SUpdateSQLResult, error) {
  43. dataValue := reflect.Indirect(reflect.ValueOf(diff))
  44. fields := reflectutils.FetchStructFieldValueSet(dataValue)
  45. var targetFields reflectutils.SStructFieldValueSet
  46. if target != nil {
  47. targetValue := reflect.Indirect(reflect.ValueOf(target))
  48. targetFields = reflectutils.FetchStructFieldValueSet(targetValue)
  49. }
  50. qChar := t.Database().backend.QuoteChar()
  51. primaries := make([]sPrimaryKeyValue, 0)
  52. vars := make([]interface{}, 0)
  53. versionFields := make([]string, 0)
  54. updatedFields := make([]string, 0)
  55. incFields := make([]string, 0)
  56. for _, c := range t.Columns() {
  57. k := c.Name()
  58. v, _ := fields.GetInterface(k)
  59. if c.IsPrimary() {
  60. if targetFields != nil {
  61. v, _ = targetFields.GetInterface(k)
  62. }
  63. if !gotypes.IsNil(v) && !c.IsZero(v) {
  64. primaries = append(primaries, sPrimaryKeyValue{
  65. key: k,
  66. value: v,
  67. })
  68. } else if c.IsText() {
  69. primaries = append(primaries, sPrimaryKeyValue{
  70. key: k,
  71. value: "",
  72. })
  73. } else {
  74. return nil, ErrEmptyPrimaryKey
  75. }
  76. continue
  77. }
  78. if c.IsUpdatedAt() {
  79. updatedFields = append(updatedFields, k)
  80. continue
  81. }
  82. if c.IsAutoVersion() {
  83. versionFields = append(versionFields, k)
  84. continue
  85. }
  86. if c.IsNumeric() && !c.IsZero(v) {
  87. incFields = append(incFields, k)
  88. vars = append(vars, v)
  89. continue
  90. }
  91. }
  92. if len(vars) == 0 {
  93. return nil, ErrNoDataToUpdate
  94. }
  95. if len(primaries) == 0 {
  96. return nil, ErrEmptyPrimaryKey
  97. }
  98. var buf bytes.Buffer
  99. buf.WriteString(fmt.Sprintf("UPDATE %s%s%s SET ", qChar, t.name, qChar))
  100. first := true
  101. for _, k := range incFields {
  102. if first {
  103. first = false
  104. } else {
  105. buf.WriteString(", ")
  106. }
  107. buf.WriteString(fmt.Sprintf("%s%s%s = %s%s%s %s ?", qChar, k, qChar, qChar, k, qChar, opcode))
  108. }
  109. for _, versionField := range versionFields {
  110. buf.WriteString(fmt.Sprintf(", %s%s%s = %s%s%s + 1", qChar, versionField, qChar, qChar, versionField, qChar))
  111. }
  112. for _, updatedField := range updatedFields {
  113. buf.WriteString(fmt.Sprintf(", %s%s%s = %s", qChar, updatedField, qChar, t.Database().backend.CurrentUTCTimeStampString()))
  114. }
  115. buf.WriteString(" WHERE ")
  116. for i, pkv := range primaries {
  117. if i > 0 {
  118. buf.WriteString(" AND ")
  119. }
  120. buf.WriteString(fmt.Sprintf("%s%s%s = ?", qChar, pkv.key, qChar))
  121. vars = append(vars, pkv.value)
  122. }
  123. if DEBUG_SQLCHEMY {
  124. log.Infof("Update: %s %s", buf.String(), vars)
  125. }
  126. return &SUpdateSQLResult{
  127. Sql: buf.String(),
  128. Vars: vars,
  129. primaries: primaries,
  130. }, nil
  131. }
  132. func (t *STableSpec) incrementInternal(diff interface{}, opcode string, target interface{}) error {
  133. if target == nil {
  134. if reflect.ValueOf(diff).Kind() != reflect.Ptr {
  135. return errors.Wrap(ErrNeedsPointer, "Incremental input must be a Pointer")
  136. }
  137. } else {
  138. if reflect.ValueOf(target).Kind() != reflect.Ptr {
  139. return errors.Wrap(ErrNeedsPointer, "Incremental update target must be a Pointer")
  140. }
  141. }
  142. intResult, err := t.incrementInternalSql(diff, opcode, target)
  143. if target != nil {
  144. err = t.execUpdateSql(target, intResult)
  145. } else {
  146. err = t.execUpdateSql(diff, intResult)
  147. }
  148. if err != nil {
  149. return errors.Wrap(err, "query after update failed")
  150. }
  151. return nil
  152. }