Browse Source

Move statement as a sub package (#1564)

Fix test

Fix bug

Move statement as a sub package

Reviewed-on: #1564
pull/1565/head
Lunny Xiao 4 months ago
parent
commit
2b62dc5a51
50 changed files with 1985 additions and 1782 deletions
  1. +1
    -1
      contexts/context_cache.go
  2. +5
    -0
      dialects/dialect.go
  3. +4
    -0
      dialects/postgres.go
  4. +90
    -0
      dialects/table_name.go
  5. +7
    -5
      dialects/table_name_test.go
  6. +49
    -0
      dialects/time.go
  7. +24
    -69
      engine.go
  8. +0
    -234
      engine_cond.go
  9. +0
    -109
      engine_table.go
  10. +0
    -2
      error.go
  11. +0
    -29
      helpers.go
  12. +1
    -26
      interface.go
  13. +31
    -0
      internal/json/json.go
  14. +79
    -0
      internal/statements/cache.go
  15. +27
    -4
      internal/statements/column_map.go
  16. +17
    -17
      internal/statements/expr_param.go
  17. +448
    -0
      internal/statements/query.go
  18. +372
    -632
      internal/statements/statement.go
  19. +4
    -28
      internal/statements/statement_args.go
  20. +184
    -0
      internal/statements/statement_test.go
  21. +1
    -1
      internal/statements/types.go
  22. +280
    -0
      internal/statements/update.go
  23. +13
    -0
      internal/utils/name.go
  24. +13
    -0
      internal/utils/reflect.go
  25. +19
    -0
      internal/utils/sql.go
  26. +30
    -0
      internal/utils/strings.go
  27. +5
    -4
      rows.go
  28. +34
    -0
      schemas/quote.go
  29. +27
    -19
      session.go
  30. +0
    -13
      session_cols.go
  31. +1
    -1
      session_cond.go
  32. +11
    -10
      session_convert.go
  33. +6
    -6
      session_delete.go
  34. +5
    -75
      session_exist.go
  35. +19
    -69
      session_find.go
  36. +5
    -4
      session_find_test.go
  37. +9
    -8
      session_get.go
  38. +3
    -2
      session_get_test.go
  39. +40
    -40
      session_insert.go
  40. +8
    -6
      session_iterate.go
  41. +4
    -68
      session_query.go
  42. +2
    -16
      session_raw.go
  43. +13
    -13
      session_schema.go
  44. +7
    -24
      session_stats.go
  45. +3
    -2
      session_tx_test.go
  46. +46
    -46
      session_update.go
  47. +13
    -12
      session_update_test.go
  48. +0
    -174
      statement_test.go
  49. +20
    -9
      tags_test.go
  50. +5
    -4
      types_test.go

context_cache.go → contexts/context_cache.go View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package contexts
// ContextCache is the interface that operates the cache data.
type ContextCache interface {

+ 5
- 0
dialects/dialect.go View File

@ -41,6 +41,7 @@ type Dialect interface {
DBType() DBType
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
DefaultSchema() string
DriverName() string
DataSourceName() string
@ -103,6 +104,10 @@ func (b *Base) SetLogger(logger log.Logger) {
b.logger = logger
}
func (b *Base) DefaultSchema() string {
return ""
}
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
b.db, b.dialect, b.uri = db, dialect, uri
b.driverName, b.dataSourceName = drivername, dataSourceName

+ 4
- 0
dialects/postgres.go View File

@ -788,6 +788,10 @@ func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string
return nil
}
func (db *postgres) DefaultSchema() string {
return PostgresPublicSchema
}
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {

+ 90
- 0
dialects/table_name.go View File

@ -0,0 +1,90 @@
// Copyright 2015 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"fmt"
"reflect"
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names"
)
// TableNameWithSchema will add schema prefix on table name if possible
func TableNameWithSchema(dialect Dialect, tableName string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if dialect.URI().Schema != "" &&
dialect.URI().Schema != dialect.DefaultSchema() &&
strings.Index(tableName, ".") == -1 {
return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName)
}
return tableName
}
// TableNameNoSchema returns table name with given tableName
func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string {
quote := dialect.Quoter().Quote
switch tableName.(type) {
case []string:
t := tableName.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1]))
} else if len(t) == 1 {
return quote(t[0])
}
case []interface{}:
t := tableName.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case names.TableName:
table = f.(names.TableName).TableName()
default:
v := utils.ReflectValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = names.GetTableName(mapper, v)
} else {
table = quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return quote(table)
}
case names.TableName:
return tableName.(names.TableName).TableName()
case string:
return tableName.(string)
case reflect.Value:
v := tableName.(reflect.Value)
return names.GetTableName(mapper, v)
default:
v := utils.ReflectValue(tableName)
t := v.Type()
if t.Kind() == reflect.Struct {
return names.GetTableName(mapper, v)
}
return quote(fmt.Sprintf("%v", tableName))
}
return ""
}
// FullTableName returns table name with quote and schema according parameter
func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string {
tbName := TableNameNoSchema(dialect, mapper, bean)
if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) {
tbName = TableNameWithSchema(dialect, tbName)
}
return tbName
}

engine_table_test.go → dialects/table_name_test.go View File

@ -2,11 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"testing"
"xorm.io/xorm/names"
"github.com/stretchr/testify/assert"
)
@ -20,9 +22,9 @@ func (mcc *MCC) TableName() string {
return "mcc"
}
func TestTableName1(t *testing.T) {
assert.NoError(t, prepareEngine())
func TestFullTableName(t *testing.T) {
dialect := QueryDialect("mysql")
assert.EqualValues(t, "mcc", testEngine.TableName(new(MCC)))
assert.EqualValues(t, "mcc", testEngine.TableName("mcc"))
assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{}))
assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc"))
}

+ 49
- 0
dialects/time.go View File

@ -0,0 +1,49 @@
// Copyright 2015 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"time"
"xorm.io/xorm/schemas"
)
// FormatTime format time as column type
func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
}
func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {
return nil
}
return ""
}
if col.TimeZone != nil {
return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone))
}
return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone))
}

+ 24
- 69
engine.go View File

@ -18,7 +18,6 @@ import (
"strings"
"time"
"xorm.io/builder"
"xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
@ -65,25 +64,6 @@ func (engine *Engine) BufferSize(size int) *Session {
return session.BufferSize(size)
}
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(col *schemas.Column) builder.Cond {
var cond = builder.NewCond()
if col.SQLType.IsNumeric() {
cond = builder.Eq{col.Name: 0}
} else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if engine.dialect.DBType() != schemas.MSSQL {
cond = builder.Eq{col.Name: utils.ZeroTime1}
}
}
if col.Nullable {
cond = cond.Or(builder.IsNull{col.Name})
}
return cond
}
// ShowSQL show SQL statement or not on logger if log level is great than INFO
func (engine *Engine) ShowSQL(show ...bool) {
engine.logger.ShowSQL(show...)
@ -237,7 +217,7 @@ func (engine *Engine) NoCascade() *Session {
// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
engine.SetCacher(engine.TableName(bean, true), cacher)
engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher)
return nil
}
@ -759,13 +739,13 @@ func (t *Table) IsValid() bool {
}
// TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table {
v := rValue(bean)
func (engine *Engine) TableInfo(bean interface{}) (*Table, error) {
v := utils.ReflectValue(bean)
tb, err := engine.tagParser.MapType(v)
if err != nil {
engine.logger.Error(err)
return nil, err
}
return &Table{tb, engine.TableName(bean)}
return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil
}
// IsTableEmpty if a table has any reocrd
@ -787,6 +767,11 @@ func (engine *Engine) IDOf(bean interface{}) schemas.PK {
return engine.IDOfV(reflect.ValueOf(bean))
}
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}
// IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
pk, err := engine.idOfV(rv)
@ -873,7 +858,7 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
// ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
tableName := engine.TableName(bean)
tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
@ -885,7 +870,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans {
tableName := engine.TableName(bean)
tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
@ -908,8 +893,8 @@ func (engine *Engine) Sync(beans ...interface{}) error {
defer session.Close()
for _, bean := range beans {
v := rValue(bean)
tableNameNoSchema := engine.TableName(bean)
v := utils.ReflectValue(bean)
tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
table, err := engine.tagParser.MapType(v)
if err != nil {
return err
@ -946,7 +931,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
err = session.addColumn(col.Name)
@ -957,7 +942,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
for name, index := range table.Indexes {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
if index.Type == schemas.UniqueType {
@ -966,7 +951,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
@ -981,7 +966,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if !isExist {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
@ -1250,45 +1235,11 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) {
if !col.DisableTimeZone && col.TimeZone != nil {
tz = col.TimeZone
}
return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
}
func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {
return nil
}
return ""
}
if col.TimeZone != nil {
return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone))
}
return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ))
}
// formatTime format time as column type
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if engine.dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t)
}
// GetColumnMapper returns the column name mapper
@ -1332,3 +1283,7 @@ func (engine *Engine) Unscoped() *Session {
session.isAutoClose = true
return session.Unscoped()
}
func (engine *Engine) tbNameWithSchema(v string) string {
return dialects.TableNameWithSchema(engine.dialect, v)
}

+ 0
- 234
engine_cond.go View File

@ -1,234 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
"time"
"xorm.io/builder"
"xorm.io/xorm/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
var conds []builder.Cond
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if engine.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
continue
}
var colName string
if addedTableName {
var nm = tableName
if len(aliasName) > 0 {
nm = aliasName
}
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
} else {
colName = engine.Quote(col.Name)
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
engine.logger.Warn(err)
}
continue
}
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
conds = append(conds, engine.CondDeleted(col))
}
fieldValue := *fieldValuePtr
if fieldValue.Interface() == nil {
continue
}
fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}
if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
conds = append(conds, builder.Eq{colName: nil})
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}
var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.formatColTime(col, t)
} else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else {
if col.SQLType.IsJson() {
if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
table, err := engine.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !utils.IsZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
}
}
}
case reflect.Array:
continue
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else {
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
conds = append(conds, builder.Eq{colName: val})
}
return builder.And(conds...), nil
}

+ 0
- 109
engine_table.go View File

@ -1,109 +0,0 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"reflect"
"strings"
"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
// tbNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == schemas.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != dialects.PostgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
func isSubQuery(tbName string) bool {
const selStr = "select"
if len(tbName) <= len(selStr)+1 {
return false
}
return strings.EqualFold(tbName[:len(selStr)], selStr) || strings.EqualFold(tbName[:len(selStr)+1], "("+selStr)
}
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] && !isSubQuery(tbName) {
tbName = engine.tbNameWithSchema(tbName)
}
return tbName
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *schemas.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case names.TableName:
table = f.(names.TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = names.GetTableName(engine.GetTableMapper(), v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return engine.Quote(table)
}
case names.TableName:
return tablename.(names.TableName).TableName()
case string:
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return names.GetTableName(engine.GetTableMapper(), v)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return names.GetTableName(engine.GetTableMapper(), v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
return ""
}

+ 0
- 2
error.go View File

@ -26,8 +26,6 @@ var (
ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type")
// ErrUnSupportedSQLType parameter of SQL is not supported
ErrUnSupportedSQLType = errors.New("Unsupported sql type")
)
// ErrFieldIsNotExist columns does not exist

+ 0
- 29
helpers.go View File

@ -9,7 +9,6 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
@ -138,26 +137,6 @@ func int64ToInt(id int64, tp reflect.Type) interface{} {
return int64ToIntValue(id, tp).Interface()
}
func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}
func splitNoCase(s, sep string) []string {
idx := indexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.Split(s, s[idx:idx+len(sep)])
}
func splitNNoCase(s, sep string, n int) []string {
idx := indexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.SplitN(s, s[idx:idx+len(sep)], n)
}
func makeArray(elem string, count int) []string {
res := make([]string, count)
for i := 0; i < count; i++ {
@ -166,10 +145,6 @@ func makeArray(elem string, count int) []string {
return res
}
func rValue(bean interface{}) reflect.Value {
return reflect.Indirect(reflect.ValueOf(bean))
}
func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
// return reflect.TypeOf(sliceValue.Interface())
@ -183,10 +158,6 @@ func structName(v reflect.Type) string {
return v.Name()
}
func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}
func formatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}

+ 1
- 26
interface.go View File

@ -7,7 +7,6 @@ package xorm
import (
"context"
"database/sql"
"encoding/json"
"reflect"
"time"
@ -113,7 +112,7 @@ type EngineInterface interface {
Sync(...interface{}) error
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
TableInfo(bean interface{}) (*Table, error)
TableName(interface{}, ...bool) string
UnMapType(reflect.Type)
}
@ -123,27 +122,3 @@ var (
_ EngineInterface = &Engine{}
_ EngineInterface = &EngineGroup{}
)
// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)
// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}
// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

+ 31
- 0
internal/json/json.go View File

@ -0,0 +1,31 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "encoding/json"
// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)
// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}
// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

+ 79
- 0
internal/statements/cache.go View File

@ -0,0 +1,79 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (statement *Statement) ConvertIDSQL(sqlStr string) string {
if statement.RefTable != nil {
cols := statement.RefTable.PKColumns()
if len(cols) == 0 {
return ""
}
colstrs := statement.joinColumns(cols, false)
sqls := utils.SplitNNoCase(sqlStr, " from ", 2)
if len(sqls) != 2 {
return ""
}
var top string
pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
return newsql
}
return ""
}
func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
return "", ""
}
colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
sqls := utils.SplitNNoCase(sqlStr, "where", 2)
if len(sqls) != 2 {
if len(sqls) == 1 {
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
colstrs, statement.quote(statement.TableName()))
}
return "", ""
}
var whereStr = sqls[1]
// TODO: for postgres only, if any other database?
var paraStr string
if statement.dialect.DBType() == schemas.POSTGRES {
paraStr = "$"
} else if statement.dialect.DBType() == schemas.MSSQL {
paraStr = ":"
}
if paraStr != "" {
if strings.Contains(sqls[1], paraStr) {
dollers := strings.Split(sqls[1], paraStr)
whereStr = dollers[0]
for i, c := range dollers[1:] {
ccs := strings.SplitN(c, " ", 2)
whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
}
}
}
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
colstrs, statement.quote(statement.TableName()),
whereStr)
}

statement_columnmap.go → internal/statements/column_map.go View File

@ -2,13 +2,17 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package statements
import "strings"
import (
"strings"
"xorm.io/xorm/schemas"
)
type columnMap []string
func (m columnMap) contain(colName string) bool {
func (m columnMap) Contain(colName string) bool {
if len(m) == 0 {
return false
}
@ -27,9 +31,28 @@ func (m columnMap) contain(colName string) bool {
}
func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
if m.Contain(colName) {
return false
}
*m = append(*m, colName)
return true
}
func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}

statement_exprparam.go → internal/statements/expr_param.go View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package statements
import (
"fmt"
@ -26,21 +26,21 @@ type exprParam struct {
}
type exprParams struct {
colNames []string
args []interface{}
ColNames []string
Args []interface{}
}
func (exprs *exprParams) Len() int {
return len(exprs.colNames)
return len(exprs.ColNames)
}
func (exprs *exprParams) addParam(colName string, arg interface{}) {
exprs.colNames = append(exprs.colNames, colName)
exprs.args = append(exprs.args, arg)
exprs.ColNames = append(exprs.ColNames, colName)
exprs.Args = append(exprs.Args, arg)
}
func (exprs *exprParams) isColExist(colName string) bool {
for _, name := range exprs.colNames {
func (exprs *exprParams) IsColExist(colName string) bool {
for _, name := range exprs.ColNames {
if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) {
return true
}
@ -49,16 +49,16 @@ func (exprs *exprParams) isColExist(colName string) bool {
}
func (exprs *exprParams) getByName(colName string) (exprParam, bool) {
for i, name := range exprs.colNames {
for i, name := range exprs.ColNames {
if strings.EqualFold(name, colName) {
return exprParam{name, exprs.args[i]}, true
return exprParam{name, exprs.Args[i]}, true
}
}
return exprParam{}, false
}
func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
for i, expr := range exprs.args {
func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error {
for i, expr := range exprs.Args {
switch arg := expr.(type) {
case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
@ -83,7 +83,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
}
w.Append(arg)
}
if i != len(exprs.args)-1 {
if i != len(exprs.Args)-1 {
if _, err := w.WriteString(","); err != nil {
return err
}
@ -93,7 +93,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
}
func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
for i, colName := range exprs.colNames {
for i, colName := range exprs.ColNames {
if _, err := w.WriteString(colName); err != nil {
return err
}
@ -101,7 +101,7 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
return err
}
switch arg := exprs.args[i].(type) {
switch arg := exprs.Args[i].(type) {
case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
return err
@ -113,10 +113,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
return err
}
default:
w.Append(exprs.args[i])
w.Append(exprs.Args[i])
}
if i+1 != len(exprs.colNames) {
if i+1 != len(exprs.ColNames) {
if _, err := w.WriteString(","); err != nil {
return err
}

+ 448
- 0
internal/statements/query.go View File

@ -0,0 +1,448 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"errors"
"fmt"
"reflect"
"strings"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
if len(sqlOrArgs) > 0 {
return ConvertSQLOrArgs(sqlOrArgs...)
}
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := statement.ProcessIDParam(); err != nil {
return "", nil, err
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
args := append(statement.joinArgs, condArgs...)
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
statement.SetRefBean(bean)
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.quote(colName)
}
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
sumSelect := strings.Join(sumStrs, ", ")
condSQL, condArgs, err := statement.GenConds(bean)
if err != nil {
return "", nil, err
}
sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.SetRefBean(bean)
}
var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
// TODO: always generate column names, not use * even if join
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
}
}
}
}
if len(columnStr) == 0 {
columnStr = "*"
}
if isStruct {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
} else {
if err := statement.ProcessIDParam(); err != nil {
return "", nil, err
}
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
var condSQL string
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.SetRefBean(beans[0])
condSQL, condArgs, err = statement.GenConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil {
return "", nil, err
}
var selectSQL = statement.SelectStr
if len(selectSQL) <= 0 {
if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
} else {
selectSQL = "count(*)"
}
}
sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var (
distinct string
dialect = statement.dialect
quote = statement.quote
fromStr = " FROM "
top, mssqlCondi, whereStr string
)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
}
if len(condSQL) > 0 {
whereStr = " WHERE " + condSQL
}
if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
} else {
fromStr += quote(statement.TableName())
}
if statement.TableAlias != "" {
if dialect.DBType() == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias)
} else {
fromStr += " AS " + quote(statement.TableAlias)
}
}
if statement.JoinStr != "" {
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
}
pLimitN := statement.LimitN
if dialect.DBType() == schemas.MSSQL {
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
}
if statement.Start > 0 {
var column string
if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes {
if len(index.Cols) == 1 {
column = index.Cols[0]
break
}
}
if len(column) == 0 {
column = statement.RefTable.ColumnsSeq()[0]
}
} else {
column = statement.RefTable.PKColumns()[0].Name
}
if statement.needTableName() {
if len(statement.TableAlias) > 0 {
column = statement.TableAlias + "." + column
} else {
column = statement.TableName() + "." + column
}
}
var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr
}
var groupStr string
if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr
}
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
}
}
var buf strings.Builder
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if len(mssqlCondi) > 0 {
if len(whereStr) > 0 {
fmt.Fprint(&buf, " AND ", mssqlCondi)
} else {
fmt.Fprint(&buf, " WHERE ", mssqlCondi)
}
}
if statement.GroupByStr != "" {
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
}
if statement.HavingStr != "" {
fmt.Fprint(&buf, " ", statement.HavingStr)
}
if needOrderBy && statement.OrderStr != "" {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
}
if needLimit {
if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE {
if statement.Start > 0 {
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
rawColStr := columnStr
if rawColStr == "*" {
rawColStr = "at.*"
}
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
}
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), nil
}
return buf.String(), nil
}
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
var sqlStr string
var args []interface{}
var joinStr string
var err error
if len(bean) == 0 {
tableName := statement.TableName()
if len(tableName) <= 0 {
return "", nil, ErrTableNotFound
}
tableName = statement.quote(tableName)
if len(statement.JoinStr) > 0 {
joinStr = statement.JoinStr
}
if statement.Conds().IsValid() {
condSQL, condArgs, err := builder.ToSQL(statement.Conds())
if err != nil {
return "", nil, err
}
if statement.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
}
args = condArgs
} else {
if statement.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if statement.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
}
args = []interface{}{}
}
} else {
beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return "", nil, errors.New("needs a pointer")
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := statement.SetRefBean(bean[0]); err != nil {
return "", nil, err
}
}
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
statement.Limit(1)
sqlStr, args, err = statement.GenGetSQL(bean[0])
if err != nil {
return "", nil, err
}
}
return sqlStr, args, nil
}
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
var sqlStr string
var args []interface{}
var err error
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
statement.cond = statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
args = append(statement.joinArgs, condArgs...)
sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}

internal/statements/statement.go
File diff suppressed because it is too large
View File


statement_args.go → internal/statements/statement_args.go View File

@ -2,7 +2,7 @@