insert.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package sqlchemy
import (
	"fmt"
	"reflect"
	"strings"
	"time"
	"unicode/utf8"

	"yunion.io/x/log"
	"yunion.io/x/pkg/errors"
	"yunion.io/x/pkg/gotypes"
	"yunion.io/x/pkg/util/reflectutils"
)
  25. // Insert perform a insert operation, the value of the record is store in dt
  26. func (t *STableSpec) Insert(dt interface{}) error {
  27. if !t.Database().backend.CanInsert() {
  28. return errors.Wrap(errors.ErrNotSupported, "Insert")
  29. }
  30. return t.insert(dt, false, false)
  31. }
  32. // InsertOrUpdate perform a insert or update operation, the value of the record is string in dt
  33. // MySQL: INSERT INTO ... ON DUPLICATE KEY UPDATE ...
  34. // works only for the cases that all values of primary keys are determeted before insert
  35. func (t *STableSpec) InsertOrUpdate(dt interface{}) error {
  36. if !t.Database().backend.CanInsertOrUpdate() {
  37. if !t.Database().backend.CanUpdate() {
  38. return t.insert(dt, false, false)
  39. } else {
  40. return errors.Wrap(errors.ErrNotSupported, "InsertOrUpdate")
  41. }
  42. }
  43. return t.insert(dt, true, false)
  44. }
// InsertSqlResult is the output of InsertSqlPrep: the rendered statement,
// its positional arguments and the primary-key values that were bound.
type InsertSqlResult struct {
	// Sql is the INSERT (or insert-or-update) statement text.
	Sql string
	// Values holds the statement arguments, in placeholder order.
	Values []interface{}
	// Primaries maps primary-key column names to the values bound for them.
	// insert() uses these to reload the row after executing Sql. Generated
	// auto-increment keys are generally not recorded here; insert() reads
	// them from LastInsertId instead.
	Primaries map[string]interface{}
}
  50. func (t *STableSpec) InsertSqlPrep(data interface{}, update bool) (*InsertSqlResult, error) {
  51. beforeInsert(reflect.ValueOf(data))
  52. dataValue := reflect.ValueOf(data).Elem()
  53. dataFields := reflectutils.FetchStructFieldValueSet(dataValue)
  54. var autoIncField string
  55. createdAtFields := make([]string, 0)
  56. now := time.Now().UTC()
  57. names := make([]string, 0)
  58. format := make([]string, 0)
  59. values := make([]interface{}, 0)
  60. updates := make([]string, 0)
  61. updateValues := make([]interface{}, 0)
  62. primaryKeys := make([]string, 0)
  63. primaries := make(map[string]interface{})
  64. qChar := t.Database().backend.QuoteChar()
  65. for _, c := range t.Columns() {
  66. isAutoInc := false
  67. if c.IsAutoIncrement() {
  68. isAutoInc = true
  69. }
  70. k := c.Name()
  71. ov, find := dataFields.GetInterface(k)
  72. if !find {
  73. continue
  74. }
  75. if c.IsPrimary() {
  76. primaryKeys = append(primaryKeys, fmt.Sprintf("%s%s%s", qChar, k, qChar))
  77. }
  78. // created_at or updated_at but must not be a primary key
  79. if c.IsCreatedAt() || c.IsUpdatedAt() {
  80. createdAtFields = append(createdAtFields, k)
  81. names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
  82. if c.IsZero(ov) {
  83. if t.Database().backend.SupportMixedInsertVariables() {
  84. format = append(format, t.Database().backend.CurrentUTCTimeStampString())
  85. } else {
  86. values = append(values, now)
  87. format = append(format, "?")
  88. }
  89. } else {
  90. values = append(values, ov)
  91. format = append(format, "?")
  92. }
  93. if update && c.IsUpdatedAt() && !c.IsPrimary() {
  94. if c.IsZero(ov) {
  95. updates = append(updates, fmt.Sprintf("%s%s%s = %s", qChar, k, qChar, t.Database().backend.CurrentUTCTimeStampString()))
  96. // updateValues = append(updateValues, now)
  97. } else {
  98. updates = append(updates, fmt.Sprintf("%s%s%s = ?", qChar, k, qChar))
  99. updateValues = append(updateValues, ov)
  100. }
  101. }
  102. // unlikely if created or updated as a primary key but exec an insertOrUpdate query. QIUJIAN 2022/6/5
  103. // if c.IsPrimary() {
  104. // if c.IsZero(ov) {
  105. // primaries[k] = now
  106. // } else {
  107. // primaries[k] = ov
  108. // }
  109. // }
  110. continue
  111. }
  112. // auto_version and must not be a primary key
  113. if update && c.IsAutoVersion() {
  114. updates = append(updates, fmt.Sprintf("%s%s%s = %s%s%s + 1", qChar, k, qChar, qChar, k, qChar))
  115. continue
  116. }
  117. // empty but with default
  118. if c.IsSupportDefault() && (len(c.Default()) > 0 || c.IsString()) && !gotypes.IsNil(ov) && c.IsZero(ov) && !c.AllowZero() { // empty text value
  119. val := c.ConvertFromString(c.Default())
  120. values = append(values, val)
  121. names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
  122. format = append(format, "?")
  123. if update && !c.IsPrimary() {
  124. updates = append(updates, fmt.Sprintf("%s%s%s = ?", qChar, k, qChar))
  125. updateValues = append(updateValues, val)
  126. }
  127. if c.IsPrimary() {
  128. primaries[k] = val
  129. }
  130. continue
  131. }
  132. // not empty
  133. if !gotypes.IsNil(ov) && (!c.IsZero(ov) || (!c.IsPointer() && !c.IsText())) && !isAutoInc {
  134. // validate text width
  135. if c.IsString() && c.GetWidth() > 0 {
  136. newStr, ok := ov.(string)
  137. if ok && len(newStr) > c.GetWidth() {
  138. ov = newStr[:c.GetWidth()]
  139. }
  140. }
  141. v := c.ConvertFromValue(ov)
  142. values = append(values, v)
  143. names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
  144. format = append(format, "?")
  145. if update && !c.IsPrimary() {
  146. updates = append(updates, fmt.Sprintf("%s%s%s = ?", qChar, k, qChar))
  147. updateValues = append(updateValues, v)
  148. }
  149. if c.IsPrimary() {
  150. primaries[k] = v
  151. }
  152. continue
  153. }
  154. // empty primary but is autoinc or text
  155. if c.IsPrimary() {
  156. if isAutoInc {
  157. if len(autoIncField) > 0 {
  158. panic(fmt.Sprintf("multiple auto_increment columns: %q, %q", autoIncField, k))
  159. }
  160. autoIncField = k
  161. } else if c.IsText() {
  162. values = append(values, "")
  163. names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
  164. format = append(format, "?")
  165. primaries[k] = ""
  166. } else {
  167. return nil, errors.Wrapf(ErrEmptyPrimaryKey, "cannot insert for null primary key %q", k)
  168. }
  169. continue
  170. }
  171. // empty without default
  172. if update {
  173. updates = append(updates, fmt.Sprintf("%s%s%s = NULL", qChar, k, qChar))
  174. continue
  175. }
  176. }
  177. var insertSql string
  178. if !update {
  179. insertSql = TemplateEval(t.Database().backend.InsertSQLTemplate(), struct {
  180. Table string
  181. Columns string
  182. Values string
  183. }{
  184. Table: t.name,
  185. Columns: strings.Join(names, ", "),
  186. Values: strings.Join(format, ", "),
  187. })
  188. } else {
  189. sqlTemp := t.Database().backend.InsertOrUpdateSQLTemplate()
  190. if len(sqlTemp) > 0 {
  191. // insert into ... on duplicate update ... pattern
  192. insertSql = TemplateEval(sqlTemp, struct {
  193. Table string
  194. Columns string
  195. Values string
  196. PrimaryKeys string
  197. SetValues string
  198. }{
  199. Table: t.name,
  200. Columns: strings.Join(names, ", "),
  201. Values: strings.Join(format, ", "),
  202. PrimaryKeys: strings.Join(primaryKeys, ", "),
  203. SetValues: strings.Join(updates, ", "),
  204. })
  205. values = append(values, updateValues...)
  206. } else {
  207. // customize pattern
  208. insertSql, values = t.Database().backend.PrepareInsertOrUpdateSQL(t, names, format, primaryKeys, updates, values, updateValues)
  209. }
  210. }
  211. return &InsertSqlResult{
  212. Sql: insertSql,
  213. Values: values,
  214. Primaries: primaries,
  215. }, nil
  216. }
  217. func beforeInsert(val reflect.Value) {
  218. switch val.Kind() {
  219. case reflect.Struct:
  220. structType := val.Type()
  221. for i := 0; i < val.NumField(); i++ {
  222. fieldType := structType.Field(i)
  223. if fieldType.Anonymous {
  224. beforeInsert(val.Field(i))
  225. }
  226. }
  227. valPtr := val.Addr()
  228. afterMarshalFunc := valPtr.MethodByName("BeforeInsert")
  229. if afterMarshalFunc.IsValid() && !afterMarshalFunc.IsNil() {
  230. afterMarshalFunc.Call([]reflect.Value{})
  231. }
  232. case reflect.Ptr:
  233. beforeInsert(val.Elem())
  234. }
  235. }
  236. func (t *STableSpec) insert(data interface{}, update bool, debug bool) error {
  237. insertResult, err := t.InsertSqlPrep(data, update)
  238. if err != nil {
  239. return errors.Wrap(err, "insertSqlPrep")
  240. }
  241. if DEBUG_SQLCHEMY || debug {
  242. log.Debugf("%s values: %#v", insertResult.Sql, insertResult.Values)
  243. }
  244. results, err := t.Database().TxExec(insertResult.Sql, insertResult.Values...)
  245. if err != nil {
  246. return errors.Wrap(err, "TxExec")
  247. }
  248. if t.Database().backend.CanSupportRowAffected() {
  249. affectCnt, err := results.RowsAffected()
  250. if err != nil {
  251. return err
  252. }
  253. targetCnt := int64(1)
  254. if update {
  255. // for insertOrUpdate cases, if no duplication, targetCnt=1, else targetCnt=2
  256. targetCnt = 2
  257. }
  258. if (!update && affectCnt < 1) || affectCnt > targetCnt {
  259. return errors.Wrapf(ErrUnexpectRowCount, "Insert affected cnt %d != (1, %d)", affectCnt, targetCnt)
  260. }
  261. }
  262. /*
  263. if len(autoIncField) > 0 {
  264. lastId, err := results.LastInsertId()
  265. if err == nil {
  266. val, ok := reflectutils.FindStructFieldValue(dataValue, autoIncField)
  267. if ok {
  268. gotypes.SetValue(val, fmt.Sprint(lastId))
  269. }
  270. }
  271. }
  272. */
  273. // query the value, so default value can be feedback into the object
  274. // fields = reflectutils.FetchStructFieldNameValueInterfaces(dataValue)
  275. q := t.Query()
  276. for _, c := range t.Columns() {
  277. if c.IsPrimary() {
  278. if c.IsAutoIncrement() {
  279. lastId, err := results.LastInsertId()
  280. if err != nil {
  281. return errors.Wrap(err, "fetching lastInsertId failed")
  282. }
  283. q = q.Equals(c.Name(), lastId)
  284. } else {
  285. priVal, _ := insertResult.Primaries[c.Name()]
  286. if !gotypes.IsNil(priVal) {
  287. q = q.Equals(c.Name(), priVal)
  288. }
  289. }
  290. }
  291. }
  292. err = q.First(data)
  293. if err != nil {
  294. return errors.Wrap(err, "query after insert failed")
  295. }
  296. return nil
  297. }