| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package sqlchemy
- import (
- "fmt"
- "reflect"
- "strings"
- "time"
- "yunion.io/x/log"
- "yunion.io/x/pkg/errors"
- "yunion.io/x/pkg/gotypes"
- "yunion.io/x/pkg/util/reflectutils"
- )
- // Insert perform a insert operation, the value of the record is store in dt
- func (t *STableSpec) Insert(dt interface{}) error {
- if !t.Database().backend.CanInsert() {
- return errors.Wrap(errors.ErrNotSupported, "Insert")
- }
- return t.insert(dt, false, false)
- }
- // InsertOrUpdate perform a insert or update operation, the value of the record is string in dt
- // MySQL: INSERT INTO ... ON DUPLICATE KEY UPDATE ...
- // works only for the cases that all values of primary keys are determeted before insert
- func (t *STableSpec) InsertOrUpdate(dt interface{}) error {
- if !t.Database().backend.CanInsertOrUpdate() {
- if !t.Database().backend.CanUpdate() {
- return t.insert(dt, false, false)
- } else {
- return errors.Wrap(errors.ErrNotSupported, "InsertOrUpdate")
- }
- }
- return t.insert(dt, true, false)
- }
- type InsertSqlResult struct {
- Sql string
- Values []interface{}
- Primaries map[string]interface{}
- }
- func (t *STableSpec) InsertSqlPrep(data interface{}, update bool) (*InsertSqlResult, error) {
- beforeInsert(reflect.ValueOf(data))
- dataValue := reflect.ValueOf(data).Elem()
- dataFields := reflectutils.FetchStructFieldValueSet(dataValue)
- var autoIncField string
- createdAtFields := make([]string, 0)
- now := time.Now().UTC()
- names := make([]string, 0)
- format := make([]string, 0)
- values := make([]interface{}, 0)
- updates := make([]string, 0)
- updateValues := make([]interface{}, 0)
- primaryKeys := make([]string, 0)
- primaries := make(map[string]interface{})
- qChar := t.Database().backend.QuoteChar()
- for _, c := range t.Columns() {
- isAutoInc := false
- if c.IsAutoIncrement() {
- isAutoInc = true
- }
- k := c.Name()
- ov, find := dataFields.GetInterface(k)
- if !find {
- continue
- }
- if c.IsPrimary() {
- primaryKeys = append(primaryKeys, fmt.Sprintf("%s%s%s", qChar, k, qChar))
- }
- // created_at or updated_at but must not be a primary key
- if c.IsCreatedAt() || c.IsUpdatedAt() {
- createdAtFields = append(createdAtFields, k)
- names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
- if c.IsZero(ov) {
- if t.Database().backend.SupportMixedInsertVariables() {
- format = append(format, t.Database().backend.CurrentUTCTimeStampString())
- } else {
- values = append(values, now)
- format = append(format, "?")
- }
- } else {
- values = append(values, ov)
- format = append(format, "?")
- }
- if update && c.IsUpdatedAt() && !c.IsPrimary() {
- if c.IsZero(ov) {
- updates = append(updates, fmt.Sprintf("%s%s%s = %s", qChar, k, qChar, t.Database().backend.CurrentUTCTimeStampString()))
- // updateValues = append(updateValues, now)
- } else {
- updates = append(updates, fmt.Sprintf("%s%s%s = ?", qChar, k, qChar))
- updateValues = append(updateValues, ov)
- }
- }
- // unlikely if created or updated as a primary key but exec an insertOrUpdate query. QIUJIAN 2022/6/5
- // if c.IsPrimary() {
- // if c.IsZero(ov) {
- // primaries[k] = now
- // } else {
- // primaries[k] = ov
- // }
- // }
- continue
- }
- // auto_version and must not be a primary key
- if update && c.IsAutoVersion() {
- updates = append(updates, fmt.Sprintf("%s%s%s = %s%s%s + 1", qChar, k, qChar, qChar, k, qChar))
- continue
- }
- // empty but with default
- if c.IsSupportDefault() && (len(c.Default()) > 0 || c.IsString()) && !gotypes.IsNil(ov) && c.IsZero(ov) && !c.AllowZero() { // empty text value
- val := c.ConvertFromString(c.Default())
- values = append(values, val)
- names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
- format = append(format, "?")
- if update && !c.IsPrimary() {
- updates = append(updates, fmt.Sprintf("%s%s%s = ?", qChar, k, qChar))
- updateValues = append(updateValues, val)
- }
- if c.IsPrimary() {
- primaries[k] = val
- }
- continue
- }
- // not empty
- if !gotypes.IsNil(ov) && (!c.IsZero(ov) || (!c.IsPointer() && !c.IsText())) && !isAutoInc {
- // validate text width
- if c.IsString() && c.GetWidth() > 0 {
- newStr, ok := ov.(string)
- if ok && len(newStr) > c.GetWidth() {
- ov = newStr[:c.GetWidth()]
- }
- }
- v := c.ConvertFromValue(ov)
- values = append(values, v)
- names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
- format = append(format, "?")
- if update && !c.IsPrimary() {
- updates = append(updates, fmt.Sprintf("%s%s%s = ?", qChar, k, qChar))
- updateValues = append(updateValues, v)
- }
- if c.IsPrimary() {
- primaries[k] = v
- }
- continue
- }
- // empty primary but is autoinc or text
- if c.IsPrimary() {
- if isAutoInc {
- if len(autoIncField) > 0 {
- panic(fmt.Sprintf("multiple auto_increment columns: %q, %q", autoIncField, k))
- }
- autoIncField = k
- } else if c.IsText() {
- values = append(values, "")
- names = append(names, fmt.Sprintf("%s%s%s", qChar, k, qChar))
- format = append(format, "?")
- primaries[k] = ""
- } else {
- return nil, errors.Wrapf(ErrEmptyPrimaryKey, "cannot insert for null primary key %q", k)
- }
- continue
- }
- // empty without default
- if update {
- updates = append(updates, fmt.Sprintf("%s%s%s = NULL", qChar, k, qChar))
- continue
- }
- }
- var insertSql string
- if !update {
- insertSql = TemplateEval(t.Database().backend.InsertSQLTemplate(), struct {
- Table string
- Columns string
- Values string
- }{
- Table: t.name,
- Columns: strings.Join(names, ", "),
- Values: strings.Join(format, ", "),
- })
- } else {
- sqlTemp := t.Database().backend.InsertOrUpdateSQLTemplate()
- if len(sqlTemp) > 0 {
- // insert into ... on duplicate update ... pattern
- insertSql = TemplateEval(sqlTemp, struct {
- Table string
- Columns string
- Values string
- PrimaryKeys string
- SetValues string
- }{
- Table: t.name,
- Columns: strings.Join(names, ", "),
- Values: strings.Join(format, ", "),
- PrimaryKeys: strings.Join(primaryKeys, ", "),
- SetValues: strings.Join(updates, ", "),
- })
- values = append(values, updateValues...)
- } else {
- // customize pattern
- insertSql, values = t.Database().backend.PrepareInsertOrUpdateSQL(t, names, format, primaryKeys, updates, values, updateValues)
- }
- }
- return &InsertSqlResult{
- Sql: insertSql,
- Values: values,
- Primaries: primaries,
- }, nil
- }
- func beforeInsert(val reflect.Value) {
- switch val.Kind() {
- case reflect.Struct:
- structType := val.Type()
- for i := 0; i < val.NumField(); i++ {
- fieldType := structType.Field(i)
- if fieldType.Anonymous {
- beforeInsert(val.Field(i))
- }
- }
- valPtr := val.Addr()
- afterMarshalFunc := valPtr.MethodByName("BeforeInsert")
- if afterMarshalFunc.IsValid() && !afterMarshalFunc.IsNil() {
- afterMarshalFunc.Call([]reflect.Value{})
- }
- case reflect.Ptr:
- beforeInsert(val.Elem())
- }
- }
- func (t *STableSpec) insert(data interface{}, update bool, debug bool) error {
- insertResult, err := t.InsertSqlPrep(data, update)
- if err != nil {
- return errors.Wrap(err, "insertSqlPrep")
- }
- if DEBUG_SQLCHEMY || debug {
- log.Debugf("%s values: %#v", insertResult.Sql, insertResult.Values)
- }
- results, err := t.Database().TxExec(insertResult.Sql, insertResult.Values...)
- if err != nil {
- return errors.Wrap(err, "TxExec")
- }
- if t.Database().backend.CanSupportRowAffected() {
- affectCnt, err := results.RowsAffected()
- if err != nil {
- return err
- }
- targetCnt := int64(1)
- if update {
- // for insertOrUpdate cases, if no duplication, targetCnt=1, else targetCnt=2
- targetCnt = 2
- }
- if (!update && affectCnt < 1) || affectCnt > targetCnt {
- return errors.Wrapf(ErrUnexpectRowCount, "Insert affected cnt %d != (1, %d)", affectCnt, targetCnt)
- }
- }
- /*
- if len(autoIncField) > 0 {
- lastId, err := results.LastInsertId()
- if err == nil {
- val, ok := reflectutils.FindStructFieldValue(dataValue, autoIncField)
- if ok {
- gotypes.SetValue(val, fmt.Sprint(lastId))
- }
- }
- }
- */
- // query the value, so default value can be feedback into the object
- // fields = reflectutils.FetchStructFieldNameValueInterfaces(dataValue)
- q := t.Query()
- for _, c := range t.Columns() {
- if c.IsPrimary() {
- if c.IsAutoIncrement() {
- lastId, err := results.LastInsertId()
- if err != nil {
- return errors.Wrap(err, "fetching lastInsertId failed")
- }
- q = q.Equals(c.Name(), lastId)
- } else {
- priVal, _ := insertResult.Primaries[c.Name()]
- if !gotypes.IsNil(priVal) {
- q = q.Equals(c.Name(), priVal)
- }
- }
- }
- }
- err = q.First(data)
- if err != nil {
- return errors.Wrap(err, "query after insert failed")
- }
- return nil
- }
|