Browse Source

Move statement as a sub package (#1564)

Fix test

Fix bug

Move statement as a sub package

Reviewed-on: #1564
tags/v1.0.0
Lunny Xiao 4 weeks ago
parent
commit
2b62dc5a51
50 changed files with 1985 additions and 1782 deletions
  1. +1
    -1
      contexts/context_cache.go
  2. +5
    -0
      dialects/dialect.go
  3. +4
    -0
      dialects/postgres.go
  4. +90
    -0
      dialects/table_name.go
  5. +7
    -5
      dialects/table_name_test.go
  6. +49
    -0
      dialects/time.go
  7. +24
    -69
      engine.go
  8. +0
    -234
      engine_cond.go
  9. +0
    -109
      engine_table.go
  10. +0
    -2
      error.go
  11. +0
    -29
      helpers.go
  12. +1
    -26
      interface.go
  13. +31
    -0
      internal/json/json.go
  14. +79
    -0
      internal/statements/cache.go
  15. +27
    -4
      internal/statements/column_map.go
  16. +17
    -17
      internal/statements/expr_param.go
  17. +448
    -0
      internal/statements/query.go
  18. +372
    -632
      internal/statements/statement.go
  19. +4
    -28
      internal/statements/statement_args.go
  20. +184
    -0
      internal/statements/statement_test.go
  21. +1
    -1
      internal/statements/types.go
  22. +280
    -0
      internal/statements/update.go
  23. +13
    -0
      internal/utils/name.go
  24. +13
    -0
      internal/utils/reflect.go
  25. +19
    -0
      internal/utils/sql.go
  26. +30
    -0
      internal/utils/strings.go
  27. +5
    -4
      rows.go
  28. +34
    -0
      schemas/quote.go
  29. +27
    -19
      session.go
  30. +0
    -13
      session_cols.go
  31. +1
    -1
      session_cond.go
  32. +11
    -10
      session_convert.go
  33. +6
    -6
      session_delete.go
  34. +5
    -75
      session_exist.go
  35. +19
    -69
      session_find.go
  36. +5
    -4
      session_find_test.go
  37. +9
    -8
      session_get.go
  38. +3
    -2
      session_get_test.go
  39. +40
    -40
      session_insert.go
  40. +8
    -6
      session_iterate.go
  41. +4
    -68
      session_query.go
  42. +2
    -16
      session_raw.go
  43. +13
    -13
      session_schema.go
  44. +7
    -24
      session_stats.go
  45. +3
    -2
      session_tx_test.go
  46. +46
    -46
      session_update.go
  47. +13
    -12
      session_update_test.go
  48. +0
    -174
      statement_test.go
  49. +20
    -9
      tags_test.go
  50. +5
    -4
      types_test.go

context_cache.go → contexts/context_cache.go View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm
package contexts

// ContextCache is the interface that operates the cache data.
type ContextCache interface {

+ 5
- 0
dialects/dialect.go View File

@@ -41,6 +41,7 @@ type Dialect interface {
DBType() DBType
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
DefaultSchema() string

DriverName() string
DataSourceName() string
@@ -103,6 +104,10 @@ func (b *Base) SetLogger(logger log.Logger) {
b.logger = logger
}

func (b *Base) DefaultSchema() string {
return ""
}

func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
b.db, b.dialect, b.uri = db, dialect, uri
b.driverName, b.dataSourceName = drivername, dataSourceName


+ 4
- 0
dialects/postgres.go View File

@@ -788,6 +788,10 @@ func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string
return nil
}

func (db *postgres) DefaultSchema() string {
return PostgresPublicSchema
}

func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {


+ 90
- 0
dialects/table_name.go View File

@@ -0,0 +1,90 @@
// Copyright 2015 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dialects

import (
"fmt"
"reflect"
"strings"

"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names"
)

// TableNameWithSchema will add schema prefix on table name if possible
func TableNameWithSchema(dialect Dialect, tableName string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if dialect.URI().Schema != "" &&
dialect.URI().Schema != dialect.DefaultSchema() &&
strings.Index(tableName, ".") == -1 {
return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName)
}
return tableName
}

// TableNameNoSchema returns table name with given tableName
func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string {
quote := dialect.Quoter().Quote
switch tableName.(type) {
case []string:
t := tableName.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1]))
} else if len(t) == 1 {
return quote(t[0])
}
case []interface{}:
t := tableName.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case names.TableName:
table = f.(names.TableName).TableName()
default:
v := utils.ReflectValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = names.GetTableName(mapper, v)
} else {
table = quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return quote(table)
}
case names.TableName:
return tableName.(names.TableName).TableName()
case string:
return tableName.(string)
case reflect.Value:
v := tableName.(reflect.Value)
return names.GetTableName(mapper, v)
default:
v := utils.ReflectValue(tableName)
t := v.Type()
if t.Kind() == reflect.Struct {
return names.GetTableName(mapper, v)
}
return quote(fmt.Sprintf("%v", tableName))
}
return ""
}

// FullTableName returns table name with quote and schema according parameter
func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string {
tbName := TableNameNoSchema(dialect, mapper, bean)
if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) {
tbName = TableNameWithSchema(dialect, tbName)
}
return tbName
}

engine_table_test.go → dialects/table_name_test.go View File

@@ -2,11 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm
package dialects

import (
"testing"

"xorm.io/xorm/names"

"github.com/stretchr/testify/assert"
)

@@ -20,9 +22,9 @@ func (mcc *MCC) TableName() string {
return "mcc"
}

func TestTableName1(t *testing.T) {
assert.NoError(t, prepareEngine())
func TestFullTableName(t *testing.T) {
dialect := QueryDialect("mysql")

assert.EqualValues(t, "mcc", testEngine.TableName(new(MCC)))
assert.EqualValues(t, "mcc", testEngine.TableName("mcc"))
assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{}))
assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc"))
}

+ 49
- 0
dialects/time.go View File

@@ -0,0 +1,49 @@
// Copyright 2015 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dialects

import (
"time"

"xorm.io/xorm/schemas"
)

// FormatTime format time as column type
func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
}

func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {
return nil
}
return ""
}

if col.TimeZone != nil {
return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone))
}
return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone))
}

+ 24
- 69
engine.go View File

@@ -18,7 +18,6 @@ import (
"strings"
"time"

"xorm.io/builder"
"xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
@@ -65,25 +64,6 @@ func (engine *Engine) BufferSize(size int) *Session {
return session.BufferSize(size)
}

// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(col *schemas.Column) builder.Cond {
var cond = builder.NewCond()
if col.SQLType.IsNumeric() {
cond = builder.Eq{col.Name: 0}
} else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if engine.dialect.DBType() != schemas.MSSQL {
cond = builder.Eq{col.Name: utils.ZeroTime1}
}
}

if col.Nullable {
cond = cond.Or(builder.IsNull{col.Name})
}

return cond
}

// ShowSQL show SQL statement or not on logger if log level is great than INFO
func (engine *Engine) ShowSQL(show ...bool) {
engine.logger.ShowSQL(show...)
@@ -237,7 +217,7 @@ func (engine *Engine) NoCascade() *Session {

// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
engine.SetCacher(engine.TableName(bean, true), cacher)
engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher)
return nil
}

@@ -759,13 +739,13 @@ func (t *Table) IsValid() bool {
}

// TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table {
v := rValue(bean)
func (engine *Engine) TableInfo(bean interface{}) (*Table, error) {
v := utils.ReflectValue(bean)
tb, err := engine.tagParser.MapType(v)
if err != nil {
engine.logger.Error(err)
return nil, err
}
return &Table{tb, engine.TableName(bean)}
return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil
}

// IsTableEmpty if a table has any reocrd
@@ -787,6 +767,11 @@ func (engine *Engine) IDOf(bean interface{}) schemas.PK {
return engine.IDOfV(reflect.ValueOf(bean))
}

// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}

// IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
pk, err := engine.idOfV(rv)
@@ -873,7 +858,7 @@ func (engine *Engine) CreateUniques(bean interface{}) error {

// ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
tableName := engine.TableName(bean)
tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
@@ -885,7 +870,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans {
tableName := engine.TableName(bean)
tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
@@ -908,8 +893,8 @@ func (engine *Engine) Sync(beans ...interface{}) error {
defer session.Close()

for _, bean := range beans {
v := rValue(bean)
tableNameNoSchema := engine.TableName(bean)
v := utils.ReflectValue(bean)
tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
table, err := engine.tagParser.MapType(v)
if err != nil {
return err
@@ -946,7 +931,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
err = session.addColumn(col.Name)
@@ -957,7 +942,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}

for name, index := range table.Indexes {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
if index.Type == schemas.UniqueType {
@@ -966,7 +951,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}

@@ -981,7 +966,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}

@@ -1250,45 +1235,11 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) {
if !col.DisableTimeZone && col.TimeZone != nil {
tz = col.TimeZone
}
return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
}

func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {
return nil
}
return ""
}

if col.TimeZone != nil {
return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone))
}
return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ))
}

// formatTime format time as column type
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if engine.dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t)
}

// GetColumnMapper returns the column name mapper
@@ -1332,3 +1283,7 @@ func (engine *Engine) Unscoped() *Session {
session.isAutoClose = true
return session.Unscoped()
}

func (engine *Engine) tbNameWithSchema(v string) string {
return dialects.TableNameWithSchema(engine.dialect, v)
}

+ 0
- 234
engine_cond.go View File

@@ -1,234 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
"time"

"xorm.io/builder"
"xorm.io/xorm/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)

func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
var conds []builder.Cond
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}

if engine.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
continue
}

var colName string
if addedTableName {
var nm = tableName
if len(aliasName) > 0 {
nm = aliasName
}
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
} else {
colName = engine.Quote(col.Name)
}

fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
engine.logger.Warn(err)
}
continue
}

if col.IsDeleted && !unscoped { // tag "deleted" is enabled
conds = append(conds, engine.CondDeleted(col))
}

fieldValue := *fieldValuePtr
if fieldValue.Interface() == nil {
continue
}

fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := useAllCols

if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}

if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
conds = append(conds, builder.Eq{colName: nil})
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}

var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.formatColTime(col, t)
} else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else {
if col.SQLType.IsJson() {
if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
table, err := engine.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !utils.IsZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
}
}
}
case reflect.Array:
continue
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}

if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else {
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}

conds = append(conds, builder.Eq{colName: val})
}

return builder.And(conds...), nil
}

+ 0
- 109
engine_table.go View File

@@ -1,109 +0,0 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import (
"fmt"
"reflect"
"strings"

"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)

// tbNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == schemas.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != dialects.PostgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}

func isSubQuery(tbName string) bool {
const selStr = "select"
if len(tbName) <= len(selStr)+1 {
return false
}

return strings.EqualFold(tbName[:len(selStr)], selStr) || strings.EqualFold(tbName[:len(selStr)+1], "("+selStr)
}

// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] && !isSubQuery(tbName) {
tbName = engine.tbNameWithSchema(tbName)
}
return tbName
}

// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *schemas.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}

return table.Name
}

func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case names.TableName:
table = f.(names.TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = names.GetTableName(engine.GetTableMapper(), v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return engine.Quote(table)
}
case names.TableName:
return tablename.(names.TableName).TableName()
case string:
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return names.GetTableName(engine.GetTableMapper(), v)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return names.GetTableName(engine.GetTableMapper(), v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
return ""
}

+ 0
- 2
error.go View File

@@ -26,8 +26,6 @@ var (
ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type")
// ErrUnSupportedSQLType parameter of SQL is not supported
ErrUnSupportedSQLType = errors.New("Unsupported sql type")
)

// ErrFieldIsNotExist columns does not exist


+ 0
- 29
helpers.go View File

@@ -9,7 +9,6 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)

@@ -138,26 +137,6 @@ func int64ToInt(id int64, tp reflect.Type) interface{} {
return int64ToIntValue(id, tp).Interface()
}

func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}

func splitNoCase(s, sep string) []string {
idx := indexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.Split(s, s[idx:idx+len(sep)])
}

func splitNNoCase(s, sep string, n int) []string {
idx := indexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.SplitN(s, s[idx:idx+len(sep)], n)
}

func makeArray(elem string, count int) []string {
res := make([]string, count)
for i := 0; i < count; i++ {
@@ -166,10 +145,6 @@ func makeArray(elem string, count int) []string {
return res
}

func rValue(bean interface{}) reflect.Value {
return reflect.Indirect(reflect.ValueOf(bean))
}

func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
// return reflect.TypeOf(sliceValue.Interface())
@@ -183,10 +158,6 @@ func structName(v reflect.Type) string {
return v.Name()
}

func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}

func formatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}

+ 1
- 26
interface.go View File

@@ -7,7 +7,6 @@ package xorm
import (
"context"
"database/sql"
"encoding/json"
"reflect"
"time"

@@ -113,7 +112,7 @@ type EngineInterface interface {
Sync(...interface{}) error
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
TableInfo(bean interface{}) (*Table, error)
TableName(interface{}, ...bool) string
UnMapType(reflect.Type)
}
@@ -123,27 +122,3 @@ var (
_ EngineInterface = &Engine{}
_ EngineInterface = &EngineGroup{}
)

// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}

var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)

// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}

// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}

// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

+ 31
- 0
internal/json/json.go View File

@@ -0,0 +1,31 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package json

import "encoding/json"

// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}

var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)

// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}

// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}

// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

+ 79
- 0
internal/statements/cache.go View File

@@ -0,0 +1,79 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package statements

import (
"fmt"
"strings"

"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)

func (statement *Statement) ConvertIDSQL(sqlStr string) string {
if statement.RefTable != nil {
cols := statement.RefTable.PKColumns()
if len(cols) == 0 {
return ""
}

colstrs := statement.joinColumns(cols, false)
sqls := utils.SplitNNoCase(sqlStr, " from ", 2)
if len(sqls) != 2 {
return ""
}

var top string
pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}

newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
return newsql
}
return ""
}

func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
return "", ""
}

colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
sqls := utils.SplitNNoCase(sqlStr, "where", 2)
if len(sqls) != 2 {
if len(sqls) == 1 {
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
colstrs, statement.quote(statement.TableName()))
}
return "", ""
}

var whereStr = sqls[1]

// TODO: for postgres only, if any other database?
var paraStr string
if statement.dialect.DBType() == schemas.POSTGRES {
paraStr = "$"
} else if statement.dialect.DBType() == schemas.MSSQL {
paraStr = ":"
}

if paraStr != "" {
if strings.Contains(sqls[1], paraStr) {
dollers := strings.Split(sqls[1], paraStr)
whereStr = dollers[0]
for i, c := range dollers[1:] {
ccs := strings.SplitN(c, " ", 2)
whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
}
}
}

return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
colstrs, statement.quote(statement.TableName()),
whereStr)
}

statement_columnmap.go → internal/statements/column_map.go View File

@@ -2,13 +2,17 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm
package statements

import "strings"
import (
"strings"

"xorm.io/xorm/schemas"
)

type columnMap []string

func (m columnMap) contain(colName string) bool {
func (m columnMap) Contain(colName string) bool {
if len(m) == 0 {
return false
}
@@ -27,9 +31,28 @@ func (m columnMap) contain(colName string) bool {
}

func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
if m.Contain(colName) {
return false
}
*m = append(*m, colName)
return true
}

func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}

n := len(col.Name)

for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}

return false, false
}

statement_exprparam.go → internal/statements/expr_param.go View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm
package statements

import (
"fmt"
@@ -26,21 +26,21 @@ type exprParam struct {
}

type exprParams struct {
colNames []string
args []interface{}
ColNames []string
Args []interface{}
}

func (exprs *exprParams) Len() int {
return len(exprs.colNames)
return len(exprs.ColNames)
}

func (exprs *exprParams) addParam(colName string, arg interface{}) {
exprs.colNames = append(exprs.colNames, colName)
exprs.args = append(exprs.args, arg)
exprs.ColNames = append(exprs.ColNames, colName)
exprs.Args = append(exprs.Args, arg)
}

func (exprs *exprParams) isColExist(colName string) bool {
for _, name := range exprs.colNames {
func (exprs *exprParams) IsColExist(colName string) bool {
for _, name := range exprs.ColNames {
if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) {
return true
}
@@ -49,16 +49,16 @@ func (exprs *exprParams) isColExist(colName string) bool {
}

func (exprs *exprParams) getByName(colName string) (exprParam, bool) {
for i, name := range exprs.colNames {
for i, name := range exprs.ColNames {
if strings.EqualFold(name, colName) {
return exprParam{name, exprs.args[i]}, true
return exprParam{name, exprs.Args[i]}, true
}
}
return exprParam{}, false
}

func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
for i, expr := range exprs.args {
func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error {
for i, expr := range exprs.Args {
switch arg := expr.(type) {
case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
@@ -83,7 +83,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
}
w.Append(arg)
}
if i != len(exprs.args)-1 {
if i != len(exprs.Args)-1 {
if _, err := w.WriteString(","); err != nil {
return err
}
@@ -93,7 +93,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
}

func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
for i, colName := range exprs.colNames {
for i, colName := range exprs.ColNames {
if _, err := w.WriteString(colName); err != nil {
return err
}
@@ -101,7 +101,7 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
return err
}

switch arg := exprs.args[i].(type) {
switch arg := exprs.Args[i].(type) {
case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
return err
@@ -113,10 +113,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
return err
}
default:
w.Append(exprs.args[i])
w.Append(exprs.Args[i])
}

if i+1 != len(exprs.colNames) {
if i+1 != len(exprs.ColNames) {
if _, err := w.WriteString(","); err != nil {
return err
}

+ 448
- 0
internal/statements/query.go View File

@@ -0,0 +1,448 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package statements

import (
"errors"
"fmt"
"reflect"
"strings"

"xorm.io/builder"
"xorm.io/xorm/schemas"
)

func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
if len(sqlOrArgs) > 0 {
return ConvertSQLOrArgs(sqlOrArgs...)
}

if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}

if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}

var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}

if err := statement.ProcessIDParam(); err != nil {
return "", nil, err
}

condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}

args := append(statement.joinArgs, condArgs...)
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}

return sqlStr, args, nil
}

func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}

statement.SetRefBean(bean)

var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.quote(colName)
}
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
sumSelect := strings.Join(sumStrs, ", ")

condSQL, condArgs, err := statement.GenConds(bean)
if err != nil {
return "", nil, err
}

sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true)
if err != nil {
return "", nil, err
}

return sqlStr, append(statement.joinArgs, condArgs...), nil
}

func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.SetRefBean(bean)
}

var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
// TODO: always generate column names, not use * even if join
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
}
}
}
}

if len(columnStr) == 0 {
columnStr = "*"
}

if isStruct {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
} else {
if err := statement.ProcessIDParam(); err != nil {
return "", nil, err
}
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}

sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}

return sqlStr, append(statement.joinArgs, condArgs...), nil
}

func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}

var condSQL string
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.SetRefBean(beans[0])
condSQL, condArgs, err = statement.GenConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil {
return "", nil, err
}

var selectSQL = statement.SelectStr
if len(selectSQL) <= 0 {
if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
} else {
selectSQL = "count(*)"
}
}
sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false)
if err != nil {
return "", nil, err
}

return sqlStr, append(statement.joinArgs, condArgs...), nil
}

func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var (
distinct string
dialect = statement.dialect
quote = statement.quote
fromStr = " FROM "
top, mssqlCondi, whereStr string
)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
}
if len(condSQL) > 0 {
whereStr = " WHERE " + condSQL
}

if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
} else {
fromStr += quote(statement.TableName())
}

if statement.TableAlias != "" {
if dialect.DBType() == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias)
} else {
fromStr += " AS " + quote(statement.TableAlias)
}
}
if statement.JoinStr != "" {
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
}

pLimitN := statement.LimitN
if dialect.DBType() == schemas.MSSQL {
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
}
if statement.Start > 0 {
var column string
if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes {
if len(index.Cols) == 1 {
column = index.Cols[0]
break
}
}
if len(column) == 0 {
column = statement.RefTable.ColumnsSeq()[0]
}
} else {
column = statement.RefTable.PKColumns()[0].Name
}
if statement.needTableName() {
if len(statement.TableAlias) > 0 {
column = statement.TableAlias + "." + column
} else {
column = statement.TableName() + "." + column
}
}

var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr
}

var groupStr string
if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr
}
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
}
}

var buf strings.Builder
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if len(mssqlCondi) > 0 {
if len(whereStr) > 0 {
fmt.Fprint(&buf, " AND ", mssqlCondi)
} else {
fmt.Fprint(&buf, " WHERE ", mssqlCondi)
}
}

if statement.GroupByStr != "" {
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
}
if statement.HavingStr != "" {
fmt.Fprint(&buf, " ", statement.HavingStr)
}
if needOrderBy && statement.OrderStr != "" {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
}
if needLimit {
if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE {
if statement.Start > 0 {
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
rawColStr := columnStr
if rawColStr == "*" {
rawColStr = "at.*"
}
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
}
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), nil
}

return buf.String(), nil
}

func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}

var sqlStr string
var args []interface{}
var joinStr string
var err error
if len(bean) == 0 {
tableName := statement.TableName()
if len(tableName) <= 0 {
return "", nil, ErrTableNotFound
}

tableName = statement.quote(tableName)
if len(statement.JoinStr) > 0 {
joinStr = statement.JoinStr
}

if statement.Conds().IsValid() {
condSQL, condArgs, err := builder.ToSQL(statement.Conds())
if err != nil {
return "", nil, err
}

if statement.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
}
args = condArgs
} else {
if statement.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if statement.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
}
args = []interface{}{}
}
} else {
beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return "", nil, errors.New("needs a pointer")
}

if beanValue.Elem().Kind() == reflect.Struct {
if err := statement.SetRefBean(bean[0]); err != nil {
return "", nil, err
}
}

if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
statement.Limit(1)
sqlStr, args, err = statement.GenGetSQL(bean[0])
if err != nil {
return "", nil, err
}
}

return sqlStr, args, nil
}

func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}

var sqlStr string
var args []interface{}
var err error

if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}

var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}

statement.cond = statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}

args = append(statement.joinArgs, condArgs...)
sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}

return sqlStr, args, nil
}

internal/statements/statement.go
File diff suppressed because it is too large
View File


statement_args.go → internal/statements/statement_args.go View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm
package statements

import (
"fmt"
@@ -77,7 +77,7 @@ func convertArg(arg interface{}, convertFunc func(string) string) string {

const insertSelectPlaceHolder = true

func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case bool:
if statement.dialect.DBType() == schemas.MSSQL {
@@ -130,9 +130,9 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
return nil
}

func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}) error {
func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error {
for i, arg := range args {
if err := statement.writeArg(w, arg); err != nil {
if err := statement.WriteArg(w, arg); err != nil {
return err
}

@@ -144,27 +144,3 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}
}
return nil
}

func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error {
for i, colName := range cols {
if len(leftQuote) > 0 && colName[0] != '`' {
if _, err := w.WriteString(leftQuote); err != nil {
return err
}
}
if _, err := w.WriteString(colName); err != nil {
return err
}
if len(rightQuote) > 0 && colName[len(colName)-1] != '`' {
if _, err := w.WriteString(rightQuote); err != nil {
return err
}
}
if i+1 != len(cols) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

+ 184
- 0
internal/statements/statement_test.go View File

@@ -0,0 +1,184 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package statements

import (
"reflect"
"strings"
"testing"

"xorm.io/xorm/schemas"
)

var colStrTests = []struct {
omitColumn string
onlyToDBColumnNdx int
expected string
}{
{"", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"Code2", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"Code3", 1, "`ID`, `Caption`, `Code1`, `Code2`, `ParentID`, `Latitude`, `Longitude`"},
{"Longitude", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"},
{"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"},
}

func TestColumnsStringGeneration(t *testing.T) {
if dbType == "postgres" || dbType == "mssql" {
return
}

var statement *Statement

for ndx, testCase := range colStrTests {
statement = createTestStatement()

if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn)
}

columns := statement.RefTable.Columns()
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB
}

actual := statement.genColumnStr()

if actual != testCase.expected {
t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual)
}
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES
}
}
}

func BenchmarkColumnsStringGeneration(b *testing.B) {
b.StopTimer()

statement := createTestStatement()

testCase := colStrTests[0]

if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
}

if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
}

b.StartTimer()

for i := 0; i < b.N; i++ {
actual := statement.genColumnStr()

if actual != testCase.expected {
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
}
}
}

func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {

b.StopTimer()

mapCols := make(map[string]bool)
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
{Name: `Code1`},
{Name: `Code2`},
{Name: `Code3`},
{Name: `ParentID`},
{Name: `Latitude`},
{Name: `Longitude`},
}

for _, col := range cols {
mapCols[strings.ToLower(col.Name)] = true
}

b.StartTimer()

for i := 0; i < b.N; i++ {

for _, col := range cols {

if _, ok := getFlagForColumn(mapCols, col); !ok {
b.Fatal("Unexpected result")
}
}
}
}

func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {

b.StopTimer()

mapCols := make(map[string]bool)
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
{Name: `Code1`},
{Name: `Code2`},
{Name: `Code3`},
{Name: `ParentID`},
{Name: `Latitude`},
{Name: `Longitude`},
}

b.StartTimer()

for i := 0; i < b.N; i++ {

for _, col := range cols {

if _, ok := getFlagForColumn(mapCols, col); ok {
b.Fatal("Unexpected result")
}
}
}
}

type TestType struct {
ID int64 `xorm:"ID PK"`
IsDeleted bool `xorm:"IsDeleted"`
Caption string `xorm:"Caption"`
Code1 string `xorm:"Code1"`
Code2 string `xorm:"Code2"`
Code3 string `xorm:"Code3"`
ParentID int64 `xorm:"ParentID"`
Latitude float64 `xorm:"Latitude"`
Longitude float64 `xorm:"Longitude"`
}

func (TestType) TableName() string {
return "TestTable"
}

func createTestStatement() *Statement {
if engine, ok := testEngine.(*Engine); ok {
statement := &Statement{}
statement.Reset()
statement.Engine = engine
statement.dialect = engine.dialect
statement.SetRefValue(reflect.ValueOf(TestType{}))

return statement
} else if eg, ok := testEngine.(*EngineGroup); ok {
statement := &Statement{}
statement.Reset()
statement.Engine = eg.Engine
statement.dialect = eg.Engine.dialect
statement.SetRefValue(reflect.ValueOf(TestType{}))

return statement
}
return nil
}

types.go → internal/statements/types.go View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm
package statements

import (
"reflect"

+ 280
- 0
internal/statements/update.go View File

@@ -0,0 +1,280 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package statements

import (
"database/sql/driver"
"fmt"
"reflect"
"time"

"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)

// BuildUpdates auto generating update columnes and values according a struct
func (statement *Statement) BuildUpdates(bean interface{},
includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) ([]string, []interface{}, error) {
//engine := statement.Engine
table := statement.RefTable
allUseBool := statement.allUseBool
useAllCols := statement.useAllCols
mustColumnMap := statement.MustColumnMap
nullableMap := statement.NullableMap
columnMap := statement.ColumnMap
omitColumnMap := statement.OmitColumnMap
unscoped := statement.unscoped

var colNames = make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if col.IsCreated && !columnMap.Contain(col.Name) {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if col.IsDeleted && !unscoped {
continue
}
if omitColumnMap.Contain(col.Name) {
continue
}
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
continue
}

if col.MapType == schemas.ONLYFROMDB {
continue
}

if statement.IncrColumns.IsColExist(col.Name) {
continue
} else if statement.DecrColumns.IsColExist(col.Name) {
continue
} else if statement.ExprColumns.IsColExist(col.Name) {
continue
}

fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}

fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface())
if fieldType == nil {
continue
}

requiredField := useAllCols
includeNil := useAllCols

if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}

// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if b, ok := getFlagForColumn(nullableMap, col); ok {
if b && col.Nullable && utils.IsZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
fieldType = reflect.TypeOf(fieldValue.Interface())
includeNil = true
}
}

var val interface{}

if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
data, err := structConvert.ToDB()
if err != nil {
return nil, nil, err
}

val = data
goto APPEND
}
}

if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
data, err := structConvert.ToDB()
if err != nil {
return nil, nil, err
}

val = data
goto APPEND
}

if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
args = append(args, nil)
colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name)))
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}

switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t)
} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = nulType.Value()
} else {
if !col.SQLType.IsJson() {
table, err := statement.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
if pkField.IsValid() && (!requiredField && !utils.IsZero(pkField.Interface())) {
val = pkField.Interface()
} else {
continue
}
} else {
// TODO: how to handler?
panic("not supported")
}
}
} else {
// Blank struct could not be as update data
if requiredField || !utils.IsStructZero(fieldValue) {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
}
if col.SQLType.IsText() {
val = string(bytes)