query.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package sqlchemy
  15. import (
  16. "database/sql"
  17. "fmt"
  18. "reflect"
  19. "runtime/debug"
  20. "strings"
  21. "yunion.io/x/log"
  22. "yunion.io/x/pkg/errors"
  23. "yunion.io/x/pkg/util/reflectutils"
  24. )
  // QueryJoinType names the SQL join keyword used when another query source
  // is attached to an SQuery. The string value is the literal keyword text.
  type QueryJoinType string

  // Join types supported by SQuery.Join, LeftJoin and RightJoin.
  const (
  	// INNERJOIN is a plain (inner) JOIN.
  	INNERJOIN QueryJoinType = "JOIN"
  	// LEFTJOIN is a LEFT (outer) JOIN.
  	LEFTJOIN QueryJoinType = "LEFT JOIN"
  	// RIGHTJOIN is a RIGHT (outer) JOIN.
  	RIGHTJOIN QueryJoinType = "RIGHT JOIN"
  	// FULLJOIN is currently disabled:
  	// FULLJOIN QueryJoinType = "FULLJOIN"
  )
  36. // sQueryJoin represents the state of a Join Query
  37. type sQueryJoin struct {
  38. jointype QueryJoinType
  39. from IQuerySource
  40. condition ICondition
  41. }
  42. // SQuery is a data structure represents a SQL query in the form of
  43. //
  44. // SELECT ... FROM ... JOIN ... ON ... WHERE ... GROUP BY ... ORDER BY ... HAVING ...
  45. type SQuery struct {
  46. rawSql string
  47. fields []IQueryField
  48. distinct bool
  49. from IQuerySource
  50. joins []sQueryJoin
  51. where ICondition
  52. groupBy []IQueryField
  53. orderBy []sQueryOrder
  54. // having ICondition
  55. limit int
  56. offset int
  57. refFieldMap map[string]IQueryField
  58. snapshot string
  59. db *SDatabase
  60. }
  61. func (tq *SQuery) Copy() *SQuery {
  62. q := &SQuery{
  63. rawSql: tq.rawSql,
  64. fields: []IQueryField{},
  65. refFieldMap: map[string]IQueryField{},
  66. distinct: tq.distinct,
  67. from: tq.from,
  68. joins: []sQueryJoin{},
  69. where: tq.where,
  70. groupBy: []IQueryField{},
  71. orderBy: []sQueryOrder{},
  72. limit: tq.limit,
  73. offset: tq.offset,
  74. snapshot: tq.snapshot,
  75. db: tq.db,
  76. }
  77. for i := range tq.fields {
  78. q.fields = append(q.fields, tq.fields[i])
  79. }
  80. for k := range tq.refFieldMap {
  81. q.refFieldMap[k] = tq.refFieldMap[k]
  82. }
  83. for i := range tq.joins {
  84. q.joins = append(q.joins, tq.joins[i])
  85. }
  86. for i := range tq.groupBy {
  87. q.groupBy = append(q.groupBy, tq.groupBy[i])
  88. }
  89. for i := range tq.orderBy {
  90. q.orderBy = append(q.orderBy, tq.orderBy[i])
  91. }
  92. return q
  93. }
  94. // IsGroupBy returns wether the query contains group by clauses
  95. func (tq *SQuery) IsGroupBy() bool {
  96. return len(tq.groupBy) > 0
  97. }
  98. func (tq *SQuery) hasField(f IQueryField) bool {
  99. if len(tq.fields) == 0 {
  100. return false
  101. }
  102. for i := range tq.fields {
  103. if tq.fields[i].Name() == f.Name() {
  104. return true
  105. }
  106. }
  107. return false
  108. }
  109. // AppendField appends query field to a query
  110. func (tq *SQuery) AppendField(f ...IQueryField) *SQuery {
  111. // log.Debugf("AppendField tq has fields %d", len(tq.fields))
  112. for i := range f {
  113. if !tq.hasField(f[i]) {
  114. if refField, ok := tq.refFieldMap[f[i].Name()]; ok {
  115. tq.fields = append(tq.fields, refField)
  116. delete(tq.refFieldMap, f[i].Name())
  117. } else {
  118. tq.fields = append(tq.fields, f[i])
  119. }
  120. }
  121. }
  122. return tq
  123. }
  124. func (tq *SQuery) addRefField(f IQueryField) *SQuery {
  125. if tq.refFieldMap == nil {
  126. tq.refFieldMap = make(map[string]IQueryField)
  127. }
  128. if !tq.hasField(f) {
  129. if _, ok := tq.refFieldMap[f.Name()]; !ok {
  130. tq.refFieldMap[f.Name()] = f
  131. }
  132. }
  133. return tq
  134. }
  135. func (tq *SQuery) ResetFields() *SQuery {
  136. tq.fields = make([]IQueryField, 0)
  137. tq.refFieldMap = make(map[string]IQueryField)
  138. return tq
  139. }
  140. // Query of SSubQuery generates a new query from a subquery
  141. func (sq *SSubQuery) Query(f ...IQueryField) *SQuery {
  142. return DoQuery(sq, f...)
  143. }
  144. // Query of STable generates a new query from a table
  145. func (tbl *STable) Query(f ...IQueryField) *SQuery {
  146. return DoQuery(tbl, f...)
  147. }
  148. // Query of STableSpec generates a new query from a STableSpec instance
  149. func (ts *STableSpec) Query(f ...IQueryField) *SQuery {
  150. return ts.Instance().Query(f...)
  151. }
  152. // QueryOrderType indicates the query order type, either ASC or DESC
  153. type QueryOrderType string
  154. const (
  155. // SQL_ORDER_ASC represents Ascending order
  156. SQL_ORDER_ASC QueryOrderType = "ASC"
  157. // SQL_ORDER_DESC represents Descending order
  158. SQL_ORDER_DESC QueryOrderType = "DESC"
  159. )
  160. // Equals of QueryOrderType determines whether two order type identical
  161. func (qot QueryOrderType) Equals(orderType string) bool {
  162. if strings.ToUpper(orderType) == string(qot) {
  163. return true
  164. }
  165. return false
  166. }
  167. // internal structure to store state of query order
  168. type sQueryOrder struct {
  169. field IQueryField
  170. order QueryOrderType
  171. }
  172. func (tq *SQuery) _orderBy(order QueryOrderType, fields []IQueryField) *SQuery {
  173. if tq.orderBy == nil {
  174. tq.orderBy = make([]sQueryOrder, 0)
  175. }
  176. for i := range fields {
  177. tq.orderBy = append(tq.orderBy, sQueryOrder{field: fields[i], order: order})
  178. }
  179. return tq
  180. }
  181. // Asc of SQuery does query in ascending order of specified fields
  182. func (tq *SQuery) Asc(fields ...interface{}) *SQuery {
  183. return tq._orderBy(SQL_ORDER_ASC, convertQueryField(tq, fields))
  184. }
  185. // Desc of SQuery does query in descending order of specified fields
  186. func (tq *SQuery) Desc(fields ...interface{}) *SQuery {
  187. return tq._orderBy(SQL_ORDER_DESC, convertQueryField(tq, fields))
  188. }
  189. func convertQueryField(tq IQuery, fields []interface{}) []IQueryField {
  190. nFields := make([]IQueryField, 0)
  191. for _, f := range fields {
  192. switch ff := f.(type) {
  193. case string:
  194. nFields = append(nFields, tq.Field(ff))
  195. case IQueryField:
  196. nFields = append(nFields, ff)
  197. default:
  198. log.Errorf("Invalid query field %s neither string nor IQueryField", f)
  199. }
  200. }
  201. return nFields
  202. }
  203. // GroupBy of SQuery does query group by specified fields
  204. func (tq *SQuery) GroupBy(f ...interface{}) *SQuery {
  205. if tq.groupBy == nil {
  206. tq.groupBy = make([]IQueryField, 0)
  207. }
  208. qfs := convertQueryField(tq, f)
  209. tq.groupBy = append(tq.groupBy, qfs...)
  210. return tq
  211. }
  212. // Limit of SQuery adds limit to a query
  213. func (tq *SQuery) Limit(limit int) *SQuery {
  214. tq.limit = limit
  215. return tq
  216. }
  217. // Offset of SQuery adds offset to a query
  218. func (tq *SQuery) Offset(offset int) *SQuery {
  219. tq.offset = offset
  220. return tq
  221. }
  222. func (tq *SQuery) FieldCount() int {
  223. return len(tq.fields)
  224. }
  225. // QueryFields of SQuery returns fields in SELECT clause of a query
  226. func (tq *SQuery) QueryFields() []IQueryField {
  227. if len(tq.fields) > 0 {
  228. return tq.fields
  229. }
  230. return tq.from.Fields()
  231. }
  232. // String of SQuery implemetation of SQuery for IQuery
  233. func (tq *SQuery) String(fields ...IQueryField) string {
  234. sql := queryString(tq, fields...)
  235. // log.Debugf("Query: %s", sql)
  236. return sql
  237. }
  238. // Join of SQuery joins query with another IQuerySource on specified condition
  239. func (tq *SQuery) Join(from IQuerySource, on ICondition) *SQuery {
  240. return tq._join(from, on, INNERJOIN)
  241. }
  242. // LeftJoin of SQuery left-joins query with another IQuerySource on specified condition
  243. func (tq *SQuery) LeftJoin(from IQuerySource, on ICondition) *SQuery {
  244. return tq._join(from, on, LEFTJOIN)
  245. }
  246. // RightJoin of SQuery right-joins query with another IQuerySource on specified condition
  247. func (tq *SQuery) RightJoin(from IQuerySource, on ICondition) *SQuery {
  248. return tq._join(from, on, RIGHTJOIN)
  249. }
  250. /*func (tq *SQuery) FullJoin(from IQuerySource, on ICondition) *SQuery {
  251. return tq._join(from, on, FULLJOIN)
  252. }*/
  // _join appends a JOIN clause of the given type (from ON on) and returns
  // tq for chaining.
  //
  // It panics when from's database differs from tq's. The panic message
  // tolerates a nil database on either side. Previously the message itself
  // dereferenced the nil pointer, hiding the real cause behind a nil-pointer
  // runtime error.
  func (tq *SQuery) _join(from IQuerySource, on ICondition, joinType QueryJoinType) *SQuery {
  	if fromDB := from.database(); fromDB != tq.db {
  		tqName, fromName := "<nil>", "<nil>"
  		if tq.db != nil {
  			tqName = tq.db.name
  		}
  		if fromDB != nil {
  			fromName = fromDB.name
  		}
  		panic(fmt.Sprintf("Cannot join across databases %s!=%s", tqName, fromName))
  	}
  	if tq.joins == nil {
  		tq.joins = make([]sQueryJoin, 0)
  	}
  	tq.joins = append(tq.joins, sQueryJoin{jointype: joinType, from: from, condition: on})
  	return tq
  }
  264. // Variables implementation of SQuery for IQuery
  265. func (tq *SQuery) Variables() []interface{} {
  266. vars := make([]interface{}, 0)
  267. var fromvars []interface{}
  268. fields := tq.fields
  269. for i := range fields {
  270. fromvars = fields[i].Variables()
  271. vars = append(vars, fromvars...)
  272. }
  273. if tq.from != nil {
  274. fromvars = tq.from.Variables()
  275. vars = append(vars, fromvars...)
  276. }
  277. for _, join := range tq.joins {
  278. fromvars = join.from.Variables()
  279. vars = append(vars, fromvars...)
  280. fromvars = join.condition.Variables()
  281. vars = append(vars, fromvars...)
  282. }
  283. if tq.where != nil {
  284. fromvars = tq.where.Variables()
  285. vars = append(vars, fromvars...)
  286. }
  287. /*if tq.having != nil {
  288. fromvars = tq.having.Variables()
  289. vars = append(vars, fromvars...)
  290. }*/
  291. return vars
  292. }
  293. // Distinct of SQuery indicates a distinct query results
  294. func (tq *SQuery) Distinct() *SQuery {
  295. tq.distinct = true
  296. return tq
  297. }
  298. // SubQuery of SQuery generates a SSubQuery from a Query
  299. func (tq *SQuery) SubQuery() *SSubQuery {
  300. sq := SSubQuery{
  301. query: tq.Copy(),
  302. alias: getTableAliasName(),
  303. referedFields: make(map[string]IQueryField),
  304. }
  305. return &sq
  306. }
  307. func (tq *SQuery) database() *SDatabase {
  308. return tq.db
  309. }
  // Row executes the query with QueryRow and returns the native *sql.Row.
  //
  // The SQL and its arguments are built first (and logged when
  // DEBUG_SQLCHEMY is set). Row panics when the query has no database or
  // the database has no connection handle, because it has no error return.
  func (tq *SQuery) Row() *sql.Row {
  	query, args := tq.String(), tq.Variables()
  	if DEBUG_SQLCHEMY {
  		sqlDebug("SQuery.Row", query, args)
  	}
  	switch {
  	case tq.db == nil:
  		panic("tq.db")
  	case tq.db.db == nil:
  		panic("tq.db.db")
  	}
  	return tq.db.db.QueryRow(query, args...)
  }
  // Rows executes the query and returns the native *sql.Rows cursor. On
  // success the caller must Close it.
  //
  // The SQL and its arguments are logged when DEBUG_SQLCHEMY is set. A
  // missing database or connection handle is reported as an error. Row
  // panics in that case, but Rows has an error return, so it no longer
  // nil-dereferences.
  func (tq *SQuery) Rows() (*sql.Rows, error) {
  	sqlstr := tq.String()
  	vars := tq.Variables()
  	if DEBUG_SQLCHEMY {
  		sqlDebug("SQuery.Rows", sqlstr, vars)
  	}
  	if tq.db == nil || tq.db.db == nil {
  		return nil, fmt.Errorf("sqlchemy: query has no database connection")
  	}
  	return tq.db.db.Query(sqlstr, vars...)
  }
  334. // Count of SQuery returns the count of a query
  335. // use CountWithError instead
  336. // deprecated
  337. func (tq *SQuery) Count() int {
  338. cnt, _ := tq.CountWithError()
  339. return cnt
  340. }
  341. func (tq *SQuery) CountQuery() *SQuery {
  342. tq2 := *tq
  343. tq2.limit = 0
  344. tq2.offset = 0
  345. cq := &SQuery{
  346. fields: []IQueryField{
  347. COUNT("count"),
  348. },
  349. from: tq2.SubQuery(),
  350. db: tq.database(),
  351. }
  352. return cq
  353. }
  354. // CountWithError of SQuery returns the row count of a query
  355. func (tq *SQuery) CountWithError() (int, error) {
  356. cq := tq.CountQuery()
  357. count := 0
  358. err := cq.Row().Scan(&count)
  359. if err == nil {
  360. return count, nil
  361. }
  362. log.Errorf("SQuery count %s failed: %s", cq.String(), err)
  363. return -1, err
  364. }
  // Field resolves name to a field of the query via findField (implements
  // IQuery). When nothing matches, it logs the query and name, prints the
  // current goroutine's stack, and returns nil.
  func (tq *SQuery) Field(name string) IQueryField {
  	if f := tq.findField(name); f != nil {
  		return f
  	}
  	log.Errorf("SQuery %s cannot find Field %s", tq.String(), name)
  	debug.PrintStack()
  	return nil
  }
  374. func (tq *SQuery) findField(name string) IQueryField {
  375. for _, f := range tq.fields {
  376. if f.Name() == name {
  377. // switch f.(type) {
  378. // case *SFunctionFieldBase:
  379. // log.Warningf("cannot directly reference a function alias, should use Subquery() to enclose the query")
  380. // }
  381. return f
  382. }
  383. }
  384. if f, ok := tq.refFieldMap[name]; ok {
  385. return f
  386. }
  387. f := tq.from.Field(name)
  388. if f != nil {
  389. return newQueryField(tq.from, name)
  390. }
  391. finds := make([]IQueryField, 0)
  392. for _, join := range tq.joins {
  393. f = join.from.Field(name)
  394. if f != nil {
  395. finds = append(finds, newQueryField(join.from, name))
  396. }
  397. }
  398. if len(finds) == 1 {
  399. return finds[0]
  400. } else if len(finds) > 1 {
  401. log.Errorf("Field %s found duplicated %d, please specifify the field", name, len(finds))
  402. return finds[0]
  403. }
  404. return nil
  405. }
  // IRowScanner is anything that can scan one result row into destination
  // pointers. Both *sql.Row and *sql.Rows satisfy it.
  type IRowScanner interface {
  	Scan(dest ...interface{}) error
  }
  410. func rowScan2StringMap(fields []string, row IRowScanner) (map[string]string, error) {
  411. targets := make([]interface{}, len(fields))
  412. for i := range fields {
  413. var recver interface{}
  414. targets[i] = &recver
  415. }
  416. if err := row.Scan(targets...); err != nil {
  417. return nil, err
  418. }
  419. results := make(map[string]string)
  420. for i, f := range fields {
  421. //log.Debugf("%d %s: %s", i, f, targets[i])
  422. rawValue := reflect.Indirect(reflect.ValueOf(targets[i]))
  423. if rawValue.Interface() == nil {
  424. results[f] = ""
  425. } else {
  426. value := rawValue.Interface()
  427. // log.Infof("%s %s", value, reflect.TypeOf(value))
  428. results[f] = GetStringValue(value)
  429. }
  430. }
  431. return results, nil
  432. }
  433. func (tq *SQuery) rowScan2StringMap(row IRowScanner) (map[string]string, error) {
  434. queryFields := tq.QueryFields()
  435. fields := make([]string, len(queryFields))
  436. for i, f := range queryFields {
  437. fields[i] = f.Name()
  438. }
  439. return rowScan2StringMap(fields, row)
  440. }
  441. // FirstStringMap returns query result of the first row in a stringmap(map[string]string)
  442. func (tq *SQuery) FirstStringMap() (map[string]string, error) {
  443. return tq.rowScan2StringMap(tq.Row())
  444. }
  445. // AllStringMap returns query result of all rows in an array of stringmap(map[string]string)
  446. func (tq *SQuery) AllStringMap() ([]map[string]string, error) {
  447. rows, err := tq.Rows()
  448. if err != nil {
  449. return nil, err
  450. }
  451. defer rows.Close()
  452. results := make([]map[string]string, 0)
  453. for rows.Next() {
  454. result, err := tq.rowScan2StringMap(rows)
  455. if err != nil {
  456. return nil, err
  457. }
  458. results = append(results, result)
  459. }
  460. return results, nil
  461. }
  // mapString2Struct copies the string values in mapResult into the fields
  // of destValue. Fields are located by key with
  // reflectutils.FetchStructFieldValueSet, and values are converted with
  // setValueBySQLString. Empty values and keys without a matching field are
  // skipped.
  //
  // Every conversion failure is logged, and the remaining fields are still
  // assigned. The first failure seen is returned. Map iteration order is
  // unspecified, so when several fields fail, which one is returned is not
  // fixed. Previously each field overwrote err, so a later success could
  // turn an earlier failure into a nil return.
  func mapString2Struct(mapResult map[string]string, destValue reflect.Value) error {
  	destFields := reflectutils.FetchStructFieldValueSet(destValue)
  	var firstErr error
  	for k, v := range mapResult {
  		if len(v) == 0 {
  			continue
  		}
  		fieldValue, ok := destFields.GetValue(k)
  		if !ok {
  			continue
  		}
  		if err := setValueBySQLString(fieldValue, v); err != nil {
  			log.Errorf("Set field %q value error %s", k, err)
  			if firstErr == nil {
  				firstErr = err
  			}
  		}
  	}
  	return firstErr
  }
  // callAfterQuery looks up a method named AfterQuery in val's method set
  // and, if present, calls it with no arguments, ignoring any results. A
  // value without such a method (for example a non-pointer value whose
  // AfterQuery has a pointer receiver) is left untouched.
  func callAfterQuery(val reflect.Value) {
  	hook := val.MethodByName("AfterQuery")
  	if !hook.IsValid() || hook.IsNil() {
  		return
  	}
  	hook.Call(nil)
  }
  // First executes the query and loads the first row into dest. dest must be
  // a non-nil pointer, presumably to a struct, whose fields
  // mapString2Struct can populate. After loading, dest's AfterQuery method,
  // if any, is called.
  //
  // dest is validated before the query runs, so an invalid argument no
  // longer costs a database round trip. A nil pointer, which previously
  // slipped through, is now rejected. Query and scan errors are returned
  // unchanged; for example, sql.ErrNoRows is returned when there is no row.
  func (tq *SQuery) First(dest interface{}) error {
  	destPtrValue := reflect.ValueOf(dest)
  	if destPtrValue.Kind() != reflect.Ptr || destPtrValue.IsNil() {
  		return errors.Wrap(ErrNeedsPointer, "input must be a pointer")
  	}
  	mapResult, err := tq.FirstStringMap()
  	if err != nil {
  		return err
  	}
  	if err := mapString2Struct(mapResult, destPtrValue.Elem()); err != nil {
  		return err
  	}
  	callAfterQuery(destPtrValue)
  	return nil
  }
  // All executes the query and appends one element per result row to the
  // slice that dest points to; existing elements are kept. Each element is a
  // fresh value of the slice's element type, filled by mapString2Struct. Its
  // AfterQuery method, if any, is called on a pointer to it before it is
  // appended.
  //
  // dest must be a non-nil pointer to a slice, otherwise an ErrNeedsArray
  // error is returned before the query runs. Pointers to fixed-size arrays
  // are rejected because reflect.Append only works on slices. Non-pointer
  // and nil arguments previously caused reflection panics. If a row fails
  // to load, the rows appended so far are kept and that error is returned.
  func (tq *SQuery) All(dest interface{}) error {
  	destPtr := reflect.ValueOf(dest)
  	if destPtr.Kind() != reflect.Ptr || destPtr.IsNil() || destPtr.Elem().Kind() != reflect.Slice {
  		return errors.Wrap(ErrNeedsArray, "dest is not a pointer to a slice")
  	}
  	sliceValue := destPtr.Elem()
  	elemType := sliceValue.Type().Elem()
  	mapResults, err := tq.AllStringMap()
  	if err != nil {
  		return err
  	}
  	for _, mapV := range mapResults {
  		elemPtrValue := reflect.New(elemType)
  		if err := mapString2Struct(mapV, elemPtrValue.Elem()); err != nil {
  			return err
  		}
  		callAfterQuery(elemPtrValue)
  		sliceValue.Set(reflect.Append(sliceValue, elemPtrValue.Elem()))
  	}
  	return nil
  }
  527. // Row2Map is a utility function that fetch stringmap(map[string]string) from a native sql.Row or sql.Rows
  528. func (tq *SQuery) Row2Map(row IRowScanner) (map[string]string, error) {
  529. return tq.rowScan2StringMap(row)
  530. }
  // RowMap2Struct fills dest from a row map produced by Row2Map, then calls
  // dest's AfterQuery method, if any.
  //
  // dest must be a non-nil pointer. A nil pointer is now rejected with
  // ErrNeedsPointer; previously it reached mapString2Struct as an invalid
  // reflect.Value. Conversion errors from mapString2Struct are returned.
  func (tq *SQuery) RowMap2Struct(result map[string]string, dest interface{}) error {
  	destPtrValue := reflect.ValueOf(dest)
  	if destPtrValue.Kind() != reflect.Ptr || destPtrValue.IsNil() {
  		return errors.Wrap(ErrNeedsPointer, "input must be a pointer")
  	}
  	if err := mapString2Struct(result, destPtrValue.Elem()); err != nil {
  		return err
  	}
  	callAfterQuery(destPtrValue)
  	return nil
  }
  545. // Row2Struct is a utility function that fill a struct with the value of a sql.Row or sql.Rows
  546. func (tq *SQuery) Row2Struct(row IRowScanner, dest interface{}) error {
  547. result, err := tq.rowScan2StringMap(row)
  548. if err != nil {
  549. return err
  550. }
  551. return tq.RowMap2Struct(result, dest)
  552. }
  553. // Snapshot of SQuery take a snapshot of the query, so we can tell wether the query is modified later by comparing the SQL with snapshot
  554. func (tq *SQuery) Snapshot() *SQuery {
  555. tq.snapshot = tq.String()
  556. return tq
  557. }
  // IsAltered reports whether the query's SQL text differs from the text
  // captured by the last Snapshot call. It panics if Snapshot was never
  // called, or captured an empty string.
  func (tq *SQuery) IsAltered() bool {
  	if tq.snapshot == "" {
  		panic(fmt.Sprintf("Query %s has never been snapshot when IsAltered called", tq.String()))
  	}
  	return tq.snapshot != tq.String()
  }