xorm/get.go
2023-10-28 10:59:32 +00:00

287 lines
6.9 KiB
Go

// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql"
"errors"
"fmt"
"math/big"
"reflect"
"time"
"xorm.io/xorm/v2/internal/convert"
"xorm.io/xorm/v2/internal/core"
"xorm.io/xorm/v2/internal/utils"
"xorm.io/xorm/v2/schemas"
)
// ErrObjectIsNil is returned when the bean passed to Get is a nil pointer.
var ErrObjectIsNil = errors.New("object should not be nil")
// Get retrieves one record from the database into beans[0]. For a struct
// bean, its non-empty fields are used as query conditions. The boolean result
// reports whether a record was found. When the session is in auto-close mode,
// the session is closed before Get returns.
func (session *Session) Get(beans ...any) (bool, error) {
	if !session.isAutoClose {
		return session.get(beans...)
	}
	defer session.Close()
	return session.get(beans...)
}
// isPtrOfTime reports whether v is a *time.Time, or a pointer to a struct
// whose type is convertible to schemas.TimeType. get uses it to avoid
// treating time-like struct pointers as table-mapped beans.
//
// Values that are not pointers (including a nil interface) report false
// instead of panicking in reflect.Value.Elem. A nil pointer also reports false.
func isPtrOfTime(v any) bool {
	if _, ok := v.(*time.Time); ok {
		return true
	}
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr {
		return false
	}
	// For a nil pointer Elem returns the zero Value, whose Kind is Invalid.
	el := rv.Elem()
	if el.Kind() != reflect.Struct {
		return false
	}
	return el.Type().ConvertibleTo(schemas.TimeType)
}
// get implements Get without the auto-close handling.
//
// beans[0] must be a non-nil pointer to a non-pointer value. If it points to a
// struct that is not time-like (see isPtrOfTime), it becomes the statement's
// reference bean. The statement's context cache, when present, is then
// consulted before the query and populated after a successful one. Without
// raw SQL, the query is generated from the bean with a limit of 1. The
// statement is reset on every return path.
func (session *Session) get(beans ...any) (bool, error) {
	defer session.resetStatement()
	if session.statement.LastError != nil {
		return false, session.statement.LastError
	}
	if len(beans) == 0 {
		return false, errors.New("needs at least one parameter for get")
	}

	beanValue := reflect.ValueOf(beans[0])
	switch {
	case beanValue.Kind() != reflect.Ptr:
		return false, errors.New("needs a pointer to a value")
	case beanValue.Elem().Kind() == reflect.Ptr:
		return false, errors.New("a pointer to a pointer is not allowed")
	case beanValue.IsNil():
		return false, ErrObjectIsNil
	}

	isStruct := beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(beans[0])
	if isStruct {
		if err := session.statement.SetRefBean(beans[0]); err != nil {
			return false, err
		}
	}

	var (
		sqlStr string
		args   []any
		err    error
	)
	if session.statement.RawSQL == "" {
		if len(session.statement.TableName()) == 0 {
			return false, ErrTableNotFound
		}
		session.statement.Limit(1)
		sqlStr, args, err = session.statement.GenGetSQL(beans[0])
		if err != nil {
			return false, err
		}
	} else {
		sqlStr = session.statement.GenRawSQL()
		args = session.statement.RawParams
	}

	table := session.statement.RefTable
	cache := session.statement.Context
	useCache := cache != nil && isStruct
	var cacheKey string
	if useCache {
		cacheKey = fmt.Sprintf("%v-%v", sqlStr, args)
		if hit := cache.Get(cacheKey); hit != nil {
			target := beanValue.Elem()
			cached := reflect.Indirect(reflect.ValueOf(hit))
			// Raw SQL can yield the same key for different bean types. Reuse
			// the cached value only when it fits this bean; otherwise fall
			// through and query the database instead of panicking in Set.
			if cached.IsValid() && cached.Type().AssignableTo(target.Type()) {
				session.engine.logger.Debugf("hit context cache: %s", sqlStr)
				target.Set(cached)
				session.lastSQL = ""
				session.lastSQLArgs = nil
				return true, nil
			}
		}
	}

	has, err := session.nocacheGet(beanValue.Elem().Kind(), table, beans, sqlStr, args...)
	if err != nil || !has {
		return has, err
	}

	if useCache {
		// Cache a copy, not the caller's pointer, so that changes the caller
		// later makes to its bean do not alter the cached entry. The copy is
		// shallow: nested pointers, slices and maps are still shared.
		snapshot := reflect.New(beanValue.Elem().Type())
		snapshot.Elem().Set(beanValue.Elem())
		cache.Put(cacheKey, snapshot.Interface())
	}
	return true, nil
}
// isScannableStruct reports whether bean, the only destination of a Get whose
// kind is struct, should be filled field-by-field (true). Otherwise (false) it
// is passed to the engine's generic column scanner. typeLen is the number of
// result columns.
//
// The checks run in order, so a type matching several of them is decided by
// the first match:
//   - *time.Time, sql.Scanner and *big.Float are never field-mapped;
//   - a convert.Conversion is field-mapped only when more than one column is returned.
func isScannableStruct(bean any, typeLen int) bool {
	if _, isTime := bean.(*time.Time); isTime {
		return false
	}
	if _, isScanner := bean.(sql.Scanner); isScanner {
		return false
	}
	if _, isConversion := bean.(convert.Conversion); isConversion {
		return typeLen > 1
	}
	if _, isBigFloat := bean.(*big.Float); isBigFloat {
		return false
	}
	return true
}
// nocacheGet runs sqlStr against the database and scans the first returned row
// into beans. It reports false with rows.Err() when no row is returned. Once a
// row exists, the result is always true, even if an error is also returned
// (column metadata, scanning or processors failing).
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, beans []any, sqlStr string, args ...any) (bool, error) {
	rows, err := session.queryRows(sqlStr, args...)
	if err != nil {
		return false, err
	}
	defer rows.Close()

	if hasRow := rows.Next(); !hasRow {
		return false, rows.Err()
	}

	// WARN: Although a row was found, we may still return an error below.
	colTypes, typesErr := rows.ColumnTypes()
	if typesErr != nil {
		return true, typesErr
	}
	colNames, namesErr := rows.Columns()
	if namesErr != nil {
		return true, namesErr
	}

	schema := ParseColumnsSchema(colNames, colTypes, table)
	if scanErr := session.scan(rows, table, beanKind, beans, schema, colTypes, colNames); scanErr != nil {
		return true, scanErr
	}

	// Close explicitly before running processors. Presumably this releases the
	// result set/connection first. The deferred Close still runs afterwards,
	// so this assumes Close is safe to call twice.
	rows.Close()
	return true, session.executeProcessors()
}
// scan fills beans from the current row of rows.
//
// A single bean is handled specially, based on firstBeanKind:
//   - a struct accepted by isScannableStruct is mapped field-by-field via
//     row2Slice and slice2Bean;
//   - slices go to getSlice and maps to getMap.
//
// Every other case passes all beans to session.engine.scan. That path requires
// exactly one bean per result column.
func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []any, columnsSchema *ColumnsSchema, types []*sql.ColumnType, fields []string) error {
	if len(beans) == 1 {
		target := beans[0]
		switch {
		case firstBeanKind == reflect.Struct && isScannableStruct(target, len(types)):
			rowValues, err := session.row2Slice(rows, fields, types, target)
			if err != nil {
				return err
			}
			structVal := utils.ReflectValue(target)
			_, err = session.slice2Bean(rowValues, columnsSchema, fields, target, &structVal, table)
			return err
		case firstBeanKind == reflect.Slice:
			return session.getSlice(rows, types, fields, target)
		case firstBeanKind == reflect.Map:
			return session.getMap(rows, types, fields, target)
		}
	}

	if beanCount, columnCount := len(beans), len(types); beanCount != columnCount {
		return fmt.Errorf("expected columns %d, but only %d variables", columnCount, beanCount)
	}
	return session.engine.scan(rows, fields, types, beans...)
}
// getSlice scans the current row into bean, which must be a *[]string or a
// *[]any. Column i is stored at element i:
//   - existing elements are overwritten;
//   - when the slice is shorter than the column count, the missing elements
//     are appended, so an empty slice receives every column;
//   - elements beyond the column count are left untouched.
//
// Any other slice type returns an error naming the type.
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean any) error {
	switch t := bean.(type) {
	case *[]string:
		res, err := session.engine.scanStringInterface(rows, fields, types)
		if err != nil {
			return err
		}
		for i, r := range res {
			s := r.(*sql.NullString).String
			// Overwrite where possible and append the rest. This avoids an
			// index-out-of-range panic on a partially pre-filled slice.
			if i < len(*t) {
				(*t)[i] = s
			} else {
				*t = append(*t, s)
			}
		}
		return nil
	case *[]any:
		scanResults, err := session.engine.scanInterfaces(rows, fields, types)
		if err != nil {
			return err
		}
		for i := range fields {
			v, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[i])
			if err != nil {
				return err
			}
			if i < len(*t) {
				(*t)[i] = v
			} else {
				*t = append(*t, v)
			}
		}
		return nil
	default:
		// %T prints the dynamic type; %t is the boolean verb.
		return fmt.Errorf("unsupported slice type: %T", t)
	}
}
// getMap scans the current row into bean, which must be a
// *map[string]string or a *map[string]any, keyed by column name. A nil
// destination map is allocated first instead of panicking on assignment.
// Any other map type returns an error naming the type.
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean any) error {
	switch t := bean.(type) {
	case *map[string]string:
		scanResults, err := session.engine.scanStringInterface(rows, fields, types)
		if err != nil {
			return err
		}
		if *t == nil {
			*t = make(map[string]string, len(fields))
		}
		for i, key := range fields {
			(*t)[key] = scanResults[i].(*sql.NullString).String
		}
		return nil
	case *map[string]any:
		scanResults, err := session.engine.scanInterfaces(rows, fields, types)
		if err != nil {
			return err
		}
		if *t == nil {
			*t = make(map[string]any, len(fields))
		}
		for i, key := range fields {
			v, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[i])
			if err != nil {
				return err
			}
			(*t)[key] = v
		}
		return nil
	default:
		// %T prints the dynamic type; %t is the boolean verb.
		return fmt.Errorf("unsupported map type: %T", t)
	}
}
// Exist reports whether the query built by the statement's GenExistSQL (from
// the current statement and the optional beans) returns at least one row.
// When no row is returned, it yields false together with rows.Err().
//
// Like get, it resets the statement on every return path, including the early
// error returns. When the session is in auto-close mode, the session is closed
// before Exist returns.
func (session *Session) Exist(bean ...any) (bool, error) {
	if session.isAutoClose {
		defer session.Close()
	}
	defer session.resetStatement()

	if session.statement.LastError != nil {
		return false, session.statement.LastError
	}
	sqlStr, args, err := session.statement.GenExistSQL(bean...)
	if err != nil {
		return false, err
	}
	rows, err := session.queryRows(sqlStr, args...)
	if err != nil {
		return false, err
	}
	defer rows.Close()

	// Only the presence of a first row matters; its contents are not read.
	if rows.Next() {
		return true, nil
	}
	return false, rows.Err()
}