update.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package sqlchemy
  import (
	"fmt"
	"reflect"
	"strings"
	"unicode/utf8"

	"yunion.io/x/jsonutils"
	"yunion.io/x/log"
	"yunion.io/x/pkg/errors"
	"yunion.io/x/pkg/gotypes"
	"yunion.io/x/pkg/util/reflectutils"
	"yunion.io/x/pkg/utils"
)
  26. // SUpdateSession is a struct to store the state of a update session
  27. type SUpdateSession struct {
  28. oValue reflect.Value
  29. tableSpec *STableSpec
  30. }
  31. func (ts *STableSpec) PrepareUpdate(dt interface{}) (*SUpdateSession, error) {
  32. if reflect.ValueOf(dt).Kind() != reflect.Ptr {
  33. return nil, errors.Wrap(ErrNeedsPointer, "Update input must be a Pointer")
  34. }
  35. dataValue := reflect.ValueOf(dt).Elem()
  36. fields := reflectutils.FetchStructFieldValueSet(dataValue) // fetchStructFieldNameValue(dataType, dataValue)
  37. zeroPrimary := make([]string, 0)
  38. for _, c := range ts.Columns() {
  39. k := c.Name()
  40. ov, ok := fields.GetInterface(k)
  41. if !ok {
  42. continue
  43. }
  44. if c.IsPrimary() && c.IsZero(ov) && !c.IsText() {
  45. zeroPrimary = append(zeroPrimary, k)
  46. }
  47. }
  48. if len(zeroPrimary) > 0 {
  49. return nil, errors.Wrapf(ErrEmptyPrimaryKey, "not a valid data, primary key %s empty",
  50. strings.Join(zeroPrimary, ","))
  51. }
  52. originValue := gotypes.DeepCopyRv(dataValue)
  53. us := SUpdateSession{oValue: originValue, tableSpec: ts}
  54. return &us, nil
  55. }
  56. // SUpdateDiff is a struct to store the differences for an update of a column
  57. type SUpdateDiff struct {
  58. old interface{}
  59. new interface{}
  60. col IColumnSpec
  61. }
  62. // String of SUpdateDiff returns the string representation of a SUpdateDiff
  63. func (ud *SUpdateDiff) String() string {
  64. return fmt.Sprintf("%s->%s",
  65. utils.TruncateString(ud.old, 32),
  66. utils.TruncateString(ud.new, 32))
  67. }
  68. func (ud SUpdateDiff) jsonObj() jsonutils.JSONObject {
  69. r := jsonutils.NewDict()
  70. r.Set("old", jsonutils.Marshal(ud.old))
  71. r.Set("new", jsonutils.Marshal(ud.new))
  72. return r
  73. }
  74. // UpdateDiffs is a map of SUpdateDiff whose key is the column name
  75. type UpdateDiffs map[string]SUpdateDiff
  76. // String of UpdateDiffs returns the string representation of UpdateDiffs
  77. func (uds UpdateDiffs) String() string {
  78. obj := jsonutils.NewDict()
  79. for i := range uds {
  80. obj.Set(uds[i].col.Name(), uds[i].jsonObj())
  81. }
  82. return obj.String()
  83. }
  84. func updateDiffList2Map(diffs []SUpdateDiff) UpdateDiffs {
  85. ret := make(map[string]SUpdateDiff)
  86. for i := range diffs {
  87. ret[diffs[i].col.Name()] = diffs[i]
  88. }
  89. return ret
  90. }
  // sPrimaryKeyValue pairs a primary-key column name with the value bound to
// the "column = ?" condition that selects the row to update.
type sPrimaryKeyValue struct {
	key   string      // primary-key column name
	value interface{} // value compared against that column
}
  95. type SUpdateSQLResult struct {
  96. Sql string
  97. Vars []interface{}
  98. setters []SUpdateDiff
  99. primaries []sPrimaryKeyValue
  100. }
  101. func (us *SUpdateSession) SaveUpdateSql(dt interface{}) (*SUpdateSQLResult, error) {
  102. beforeUpdateFunc := reflect.ValueOf(dt).MethodByName("BeforeUpdate")
  103. if beforeUpdateFunc.IsValid() && !beforeUpdateFunc.IsNil() {
  104. beforeUpdateFunc.Call([]reflect.Value{})
  105. }
  106. // dataType := reflect.TypeOf(dt).Elem()
  107. dataValue := reflect.ValueOf(dt).Elem()
  108. ofields := reflectutils.FetchStructFieldValueSet(us.oValue)
  109. fields := reflectutils.FetchStructFieldValueSet(dataValue)
  110. versionFields := make([]string, 0)
  111. updatedFields := make([]string, 0)
  112. primaries := make([]sPrimaryKeyValue, 0)
  113. setters := make([]SUpdateDiff, 0)
  114. for _, c := range us.tableSpec.Columns() {
  115. k := c.Name()
  116. of, _ := ofields.GetInterface(k)
  117. nf, _ := fields.GetInterface(k)
  118. if c.IsPrimary() {
  119. if !gotypes.IsNil(of) && !c.IsZero(of) {
  120. if c.IsText() {
  121. ov, _ := of.(string)
  122. nv, _ := nf.(string)
  123. if ov != nv && strings.EqualFold(ov, nv) {
  124. setters = append(setters, SUpdateDiff{old: of, new: nf, col: c})
  125. }
  126. }
  127. primaries = append(primaries, sPrimaryKeyValue{
  128. key: k,
  129. value: c.ConvertFromValue(of),
  130. })
  131. } else if c.IsText() {
  132. primaries = append(primaries, sPrimaryKeyValue{
  133. key: k,
  134. value: "",
  135. })
  136. } else {
  137. return nil, ErrEmptyPrimaryKey
  138. }
  139. continue
  140. }
  141. if c.IsAutoVersion() {
  142. versionFields = append(versionFields, k)
  143. continue
  144. }
  145. if c.IsUpdatedAt() {
  146. updatedFields = append(updatedFields, k)
  147. continue
  148. }
  149. if reflect.DeepEqual(of, nf) {
  150. continue
  151. }
  152. if of != nil && nf != nil {
  153. ofJsonStr := jsonutils.Marshal(of).String()
  154. nfJsonStr := jsonutils.Marshal(nf).String()
  155. if ofJsonStr == nfJsonStr {
  156. continue
  157. }
  158. if EqualsGrossValue(of, nf) {
  159. continue
  160. }
  161. }
  162. if c.IsZero(nf) && c.IsText() {
  163. nf = nil
  164. }
  165. setters = append(setters, SUpdateDiff{old: of, new: nf, col: c})
  166. }
  167. if len(setters) == 0 {
  168. return nil, ErrNoDataToUpdate
  169. }
  170. if len(primaries) == 0 {
  171. return nil, ErrEmptyPrimaryKey
  172. }
  173. qChar := us.tableSpec.Database().backend.QuoteChar()
  174. vars := make([]interface{}, 0)
  175. colsets := make([]string, 0)
  176. conditions := make([]string, 0)
  177. for _, udif := range setters {
  178. if gotypes.IsNil(udif.new) {
  179. colsets = append(colsets, fmt.Sprintf("%s%s%s = NULL", qChar, udif.col.Name(), qChar))
  180. } else {
  181. // validate text length
  182. if udif.col.IsString() && udif.col.GetWidth() > 0 {
  183. newStr, ok := udif.new.(string)
  184. if ok && len(newStr) > udif.col.GetWidth() {
  185. udif.new = newStr[:udif.col.GetWidth()]
  186. }
  187. }
  188. colsets = append(colsets, fmt.Sprintf("%s%s%s = ?", qChar, udif.col.Name(), qChar))
  189. vars = append(vars, udif.col.ConvertFromValue(udif.new))
  190. }
  191. }
  192. for _, versionField := range versionFields {
  193. colsets = append(colsets, fmt.Sprintf("%s%s%s = %s%s%s + 1", qChar, versionField, qChar, qChar, versionField, qChar))
  194. }
  195. for _, updatedField := range updatedFields {
  196. colsets = append(colsets, fmt.Sprintf("%s%s%s = %s", qChar, updatedField, qChar, us.tableSpec.Database().backend.CurrentUTCTimeStampString()))
  197. }
  198. for _, pkv := range primaries {
  199. conditions = append(conditions, fmt.Sprintf("%s%s%s = ?", qChar, pkv.key, qChar))
  200. vars = append(vars, pkv.value)
  201. }
  202. updateSql := TemplateEval(us.tableSpec.Database().backend.UpdateSQLTemplate(), struct {
  203. Table string
  204. Columns string
  205. Conditions string
  206. }{
  207. Table: us.tableSpec.name,
  208. Columns: strings.Join(colsets, ", "),
  209. Conditions: strings.Join(conditions, " AND "),
  210. })
  211. if DEBUG_SQLCHEMY {
  212. log.Infof("Update: %s", _sqlDebug(updateSql, vars))
  213. }
  214. return &SUpdateSQLResult{
  215. Sql: updateSql,
  216. Vars: vars,
  217. setters: setters,
  218. primaries: primaries,
  219. }, nil
  220. }
  221. func (us *SUpdateSession) saveUpdate(dt interface{}) (UpdateDiffs, error) {
  222. sqlResult, err := us.SaveUpdateSql(dt)
  223. if err != nil {
  224. return nil, errors.Wrap(err, "saveUpateSql")
  225. }
  226. err = us.tableSpec.execUpdateSql(dt, sqlResult)
  227. if err != nil {
  228. return nil, errors.Wrap(err, "execUpdateSql")
  229. }
  230. return updateDiffList2Map(sqlResult.setters), nil
  231. }
  232. func (ts *STableSpec) execUpdateSql(dt interface{}, result *SUpdateSQLResult) error {
  233. results, err := ts.Database().TxExec(result.Sql, result.Vars...)
  234. if err != nil {
  235. return errors.Wrap(err, "TxExec")
  236. }
  237. if ts.Database().backend.CanSupportRowAffected() {
  238. aCnt, err := results.RowsAffected()
  239. if err != nil {
  240. return errors.Wrap(err, "results.RowsAffected")
  241. }
  242. if aCnt > 1 {
  243. return errors.Wrapf(ErrUnexpectRowCount, "affected rows %d != 1", aCnt)
  244. }
  245. }
  246. q := ts.Query()
  247. for _, pkv := range result.primaries {
  248. q = q.Equals(pkv.key, pkv.value)
  249. }
  250. err = q.First(dt)
  251. if err != nil {
  252. return errors.Wrapf(err, "query after update failed %s", q.DebugString())
  253. }
  254. return nil
  255. }
  256. // Update method of STableSpec updates a record of a table,
  257. // dt is the point to the struct storing the record
  258. // doUpdate provides method to update the field of the record
  259. func (ts *STableSpec) Update(dt interface{}, doUpdate func() error) (UpdateDiffs, error) {
  260. if !ts.Database().backend.CanUpdate() {
  261. return nil, errors.ErrNotSupported
  262. }
  263. session, err := ts.PrepareUpdate(dt)
  264. if err != nil {
  265. return nil, errors.Wrap(err, "prepareUpdate")
  266. }
  267. err = doUpdate()
  268. if err != nil {
  269. return nil, errors.Wrap(err, "")
  270. }
  271. uds, err := session.saveUpdate(dt)
  272. if err != nil && errors.Cause(err) == ErrNoDataToUpdate {
  273. return nil, nil
  274. } else if err == nil {
  275. if DEBUG_SQLCHEMY {
  276. log.Debugf("Update diff: %s", uds)
  277. }
  278. }
  279. return uds, errors.Wrap(err, "saveUpdate")
  280. }