287 lines
6.9 KiB
Go
287 lines
6.9 KiB
Go
// Copyright 2016 The Xorm Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package xorm
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"reflect"
|
|
"time"
|
|
|
|
"xorm.io/xorm/v2/internal/convert"
|
|
"xorm.io/xorm/v2/internal/core"
|
|
"xorm.io/xorm/v2/internal/utils"
|
|
"xorm.io/xorm/v2/schemas"
|
|
)
|
|
|
|
// ErrObjectIsNil is the error returned when the object passed in is nil
// (for example a nil bean pointer handed to Get).
|
|
var ErrObjectIsNil = errors.New("object should not be nil")
|
|
|
|
// Get retrieve one record from database, bean's non-empty fields
|
|
// will be as conditions
|
|
func (session *Session) Get(beans ...any) (bool, error) {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
return session.get(beans...)
|
|
}
|
|
|
|
func isPtrOfTime(v any) bool {
|
|
if _, ok := v.(*time.Time); ok {
|
|
return true
|
|
}
|
|
|
|
el := reflect.ValueOf(v).Elem()
|
|
if el.Kind() != reflect.Struct {
|
|
return false
|
|
}
|
|
|
|
return el.Type().ConvertibleTo(schemas.TimeType)
|
|
}
|
|
|
|
func (session *Session) get(beans ...any) (bool, error) {
|
|
defer session.resetStatement()
|
|
|
|
if session.statement.LastError != nil {
|
|
return false, session.statement.LastError
|
|
}
|
|
if len(beans) == 0 {
|
|
return false, errors.New("needs at least one parameter for get")
|
|
}
|
|
|
|
beanValue := reflect.ValueOf(beans[0])
|
|
if beanValue.Kind() != reflect.Ptr {
|
|
return false, errors.New("needs a pointer to a value")
|
|
} else if beanValue.Elem().Kind() == reflect.Ptr {
|
|
return false, errors.New("a pointer to a pointer is not allowed")
|
|
} else if beanValue.IsNil() {
|
|
return false, ErrObjectIsNil
|
|
}
|
|
|
|
isStruct := beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(beans[0])
|
|
if isStruct {
|
|
if err := session.statement.SetRefBean(beans[0]); err != nil {
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
var sqlStr string
|
|
var args []any
|
|
var err error
|
|
|
|
if session.statement.RawSQL == "" {
|
|
if len(session.statement.TableName()) == 0 {
|
|
return false, ErrTableNotFound
|
|
}
|
|
session.statement.Limit(1)
|
|
sqlStr, args, err = session.statement.GenGetSQL(beans[0])
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
} else {
|
|
sqlStr = session.statement.GenRawSQL()
|
|
args = session.statement.RawParams
|
|
}
|
|
|
|
table := session.statement.RefTable
|
|
context := session.statement.Context
|
|
if context != nil && isStruct {
|
|
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
|
|
if res != nil {
|
|
session.engine.logger.Debugf("hit context cache: %s", sqlStr)
|
|
|
|
structValue := reflect.Indirect(reflect.ValueOf(beans[0]))
|
|
structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
|
|
session.lastSQL = ""
|
|
session.lastSQLArgs = nil
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
has, err := session.nocacheGet(beanValue.Elem().Kind(), table, beans, sqlStr, args...)
|
|
if err != nil || !has {
|
|
return has, err
|
|
}
|
|
|
|
if context != nil && isStruct {
|
|
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0])
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func isScannableStruct(bean any, typeLen int) bool {
|
|
switch bean.(type) {
|
|
case *time.Time:
|
|
return false
|
|
case sql.Scanner:
|
|
return false
|
|
case convert.Conversion:
|
|
return typeLen > 1
|
|
case *big.Float:
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, beans []any, sqlStr string, args ...any) (bool, error) {
|
|
rows, err := session.queryRows(sqlStr, args...)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
return false, rows.Err()
|
|
}
|
|
|
|
// WARN: Alougth rows return true, but we may also return error.
|
|
types, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
return true, err
|
|
}
|
|
fields, err := rows.Columns()
|
|
if err != nil {
|
|
return true, err
|
|
}
|
|
|
|
columnsSchema := ParseColumnsSchema(fields, types, table)
|
|
|
|
if err := session.scan(rows, table, beanKind, beans, columnsSchema, types, fields); err != nil {
|
|
return true, err
|
|
}
|
|
rows.Close()
|
|
|
|
return true, session.executeProcessors()
|
|
}
|
|
|
|
func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []any, columnsSchema *ColumnsSchema, types []*sql.ColumnType, fields []string) error {
|
|
if len(beans) == 1 {
|
|
bean := beans[0]
|
|
switch firstBeanKind {
|
|
case reflect.Struct:
|
|
if !isScannableStruct(bean, len(types)) {
|
|
break
|
|
}
|
|
scanResults, err := session.row2Slice(rows, fields, types, bean)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dataStruct := utils.ReflectValue(bean)
|
|
_, err = session.slice2Bean(scanResults, columnsSchema, fields, bean, &dataStruct, table)
|
|
return err
|
|
case reflect.Slice:
|
|
return session.getSlice(rows, types, fields, bean)
|
|
case reflect.Map:
|
|
return session.getMap(rows, types, fields, bean)
|
|
}
|
|
}
|
|
|
|
if len(beans) != len(types) {
|
|
return fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
|
|
}
|
|
|
|
return session.engine.scan(rows, fields, types, beans...)
|
|
}
|
|
|
|
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean any) error {
|
|
switch t := bean.(type) {
|
|
case *[]string:
|
|
res, err := session.engine.scanStringInterface(rows, fields, types)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
needAppend := len(*t) == 0 // both support slice is empty or has been initlized
|
|
for i, r := range res {
|
|
if needAppend {
|
|
*t = append(*t, r.(*sql.NullString).String)
|
|
} else {
|
|
(*t)[i] = r.(*sql.NullString).String
|
|
}
|
|
}
|
|
return nil
|
|
case *[]any:
|
|
scanResults, err := session.engine.scanInterfaces(rows, fields, types)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
needAppend := len(*t) == 0
|
|
for ii := range fields {
|
|
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if needAppend {
|
|
*t = append(*t, s)
|
|
} else {
|
|
(*t)[ii] = s
|
|
}
|
|
}
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unspoorted slice type: %t", t)
|
|
}
|
|
}
|
|
|
|
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean any) error {
|
|
switch t := bean.(type) {
|
|
case *map[string]string:
|
|
scanResults, err := session.engine.scanStringInterface(rows, fields, types)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for ii, key := range fields {
|
|
(*t)[key] = scanResults[ii].(*sql.NullString).String
|
|
}
|
|
return nil
|
|
case *map[string]any:
|
|
scanResults, err := session.engine.scanInterfaces(rows, fields, types)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for ii, key := range fields {
|
|
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
(*t)[key] = s
|
|
}
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unspoorted map type: %t", t)
|
|
}
|
|
}
|
|
|
|
// Exist returns true if the record exist otherwise return false
|
|
func (session *Session) Exist(bean ...any) (bool, error) {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
if session.statement.LastError != nil {
|
|
return false, session.statement.LastError
|
|
}
|
|
|
|
sqlStr, args, err := session.statement.GenExistSQL(bean...)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
rows, err := session.queryRows(sqlStr, args...)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
return true, nil
|
|
}
|
|
return false, rows.Err()
|
|
}
|