Fix pk bug #1602

Merged
lunny merged 1 commits from lunny/fix_pk into master 2020-03-13 08:57:37 +00:00
8 changed files with 172 additions and 111 deletions

79
internal/statements/pk.go Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"reflect"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
var (
ptrPkType = reflect.TypeOf(&schemas.PK{})
pkType = reflect.TypeOf(schemas.PK{})
stringType = reflect.TypeOf("")
intType = reflect.TypeOf(int64(0))
uintType = reflect.TypeOf(uint64(0))
)
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) ID(id interface{}) *Statement {
switch t := id.(type) {
case *schemas.PK:
statement.idParam = *t
case schemas.PK:
statement.idParam = t
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
statement.idParam = schemas.PK{id}
default:
idValue := reflect.ValueOf(id)
idType := idValue.Type()
switch idType.Kind() {
case reflect.String:
statement.idParam = schemas.PK{idValue.Convert(stringType).Interface()}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
statement.idParam = schemas.PK{idValue.Convert(intType).Interface()}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
statement.idParam = schemas.PK{idValue.Convert(uintType).Interface()}
case reflect.Slice:
if idType.ConvertibleTo(pkType) {
statement.idParam = idValue.Convert(pkType).Interface().(schemas.PK)
}
case reflect.Ptr:
if idType.ConvertibleTo(ptrPkType) {
statement.idParam = idValue.Convert(ptrPkType).Elem().Interface().(schemas.PK)
}
}
}
if statement.idParam == nil {
statement.LastError = fmt.Errorf("ID param %#v is not supported", id)
}
return statement
}
func (statement *Statement) ProcessIDParam() error {
if statement.idParam == nil || statement.RefTable == nil {
return nil
}
if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) {
fmt.Println("=====", statement.RefTable.PrimaryKeys, statement.idParam)
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
len(statement.RefTable.PrimaryKeys),
len(statement.idParam),
)
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
statement.cond = statement.cond.And(builder.Eq{colName: statement.idParam[i]})
}
return nil
}

View File

@ -41,7 +41,7 @@ type Statement struct {
tagParser *tags.Parser
Start int
LimitN *int
idParam *schemas.PK
idParam schemas.PK
OrderStr string
JoinStr string
joinArgs []interface{}
@ -319,34 +319,6 @@ func (statement *Statement) TableName() string {
return statement.tableName
}
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) ID(id interface{}) *Statement {
idValue := reflect.ValueOf(id)
idType := reflect.TypeOf(idValue.Interface())
switch idType {
case ptrPkType:
if pkPtr, ok := (id).(*schemas.PK); ok {
statement.idParam = pkPtr
return statement
}
case pkType:
if pk, ok := (id).(schemas.PK); ok {
statement.idParam = &pk
return statement
}
}
switch idType.Kind() {
case reflect.String:
statement.idParam = &schemas.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
return statement
}
statement.idParam = &schemas.PK{id}
return statement
}
// Incr Generate "Update ... Set column = column + arg" statement
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
if len(arg) > 0 {
@ -981,25 +953,6 @@ func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
return "", nil, ErrUnSupportedType
}
func (statement *Statement) ProcessIDParam() error {
if statement.idParam == nil || statement.RefTable == nil {
return nil
}
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
len(statement.RefTable.PrimaryKeys),
len(*statement.idParam),
)
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
}
return nil
}
func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
var colnames = make([]string, len(cols))
for i, col := range cols {

View File

@ -1,16 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"reflect"
"xorm.io/xorm/schemas"
)
var (
ptrPkType = reflect.TypeOf(&schemas.PK{})
pkType = reflect.TypeOf(schemas.PK{})
)

View File

@ -18,58 +18,73 @@ import (
"xorm.io/xorm/schemas"
)
func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) (bool, error) {
columnMap := statement.ColumnMap
omitColumnMap := statement.OmitColumnMap
unscoped := statement.unscoped
if !includeVersion && col.IsVersion {
return false, nil
}
if col.IsCreated && !columnMap.Contain(col.Name) {
return false, nil
}
if !includeUpdated && col.IsUpdated {
return false, nil
}
if !includeAutoIncr && col.IsAutoIncrement {
return false, nil
}
if col.IsDeleted && !unscoped {
return false, nil
}
if omitColumnMap.Contain(col.Name) {
return false, nil
}
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
return false, nil
}
if col.MapType == schemas.ONLYFROMDB {
return false, nil
}
if statement.IncrColumns.IsColExist(col.Name) {
return false, nil
} else if statement.DecrColumns.IsColExist(col.Name) {
return false, nil
} else if statement.ExprColumns.IsColExist(col.Name) {
return false, nil
}
return true, nil
}
// BuildUpdates auto generating update columnes and values according a struct
func (statement *Statement) BuildUpdates(bean interface{},
func (statement *Statement) BuildUpdates(tableValue reflect.Value,
includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) ([]string, []interface{}, error) {
//engine := statement.Engine
table := statement.RefTable
allUseBool := statement.allUseBool
useAllCols := statement.useAllCols
mustColumnMap := statement.MustColumnMap
nullableMap := statement.NullableMap
columnMap := statement.ColumnMap
omitColumnMap := statement.OmitColumnMap
unscoped := statement.unscoped
var colNames = make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil,
includeAutoIncr, update)
if err != nil {
return nil, nil, err
}
if col.IsCreated && !columnMap.Contain(col.Name) {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if col.IsDeleted && !unscoped {
continue
}
if omitColumnMap.Contain(col.Name) {
continue
}
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
if !ok {
continue
}
if col.MapType == schemas.ONLYFROMDB {
continue
}
if statement.IncrColumns.IsColExist(col.Name) {
continue
} else if statement.DecrColumns.IsColExist(col.Name) {
continue
} else if statement.ExprColumns.IsColExist(col.Name) {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
fieldValuePtr, err := col.ValueOfV(&tableValue)
if err != nil {
return nil, nil, err
}
@ -273,9 +288,6 @@ func (statement *Statement) BuildUpdates(bean interface{},
APPEND:
args = append(args, val)
if col.IsPrimaryKey {
continue
}
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
}

View File

@ -21,7 +21,7 @@ const (
type Column struct {
Name string
TableName string
FieldName string
FieldName string // Avaiable only when parsed from a struct
SQLType SQLType
IsJSON bool
Length int

View File

@ -53,13 +53,9 @@ func (table *Table) ColumnsSeq() []string {
}
func (table *Table) columnsByName(name string) []*Column {
n := len(name)
for k := range table.columnsMap {
if len(k) != n {
continue
}
for k, cols := range table.columnsMap {
if strings.EqualFold(k, name) {
return table.columnsMap[k]
return cols
}
}
return nil

View File

@ -177,7 +177,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
if session.statement.ColumnStr() == "" {
colNames, args, err = session.statement.BuildUpdates(bean, false, false,
colNames, args, err = session.statement.BuildUpdates(v, false, false,
false, false, true)
} else {
colNames, args, err = session.genUpdateColumns(bean)

View File

@ -1303,3 +1303,40 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
assert.NoError(t, err)
assertGetRecord()
}
func TestUpdateMultiplePK(t *testing.T) {
type TestUpdateMultiplePKStruct struct {
Id string `xorm:"notnull pk" description:"唯一ID号"`
Name string `xorm:"notnull pk" description:"名称"`
Value string `xorm:"notnull varchar(4000)" description:"值"`
}
assert.NoError(t, prepareEngine())
assertSync(t, new(TestUpdateMultiplePKStruct))
test := &TestUpdateMultiplePKStruct{
Id: "ID1",
Name: "Name1",
Value: "1",
}
_, err := testEngine.Insert(test)
assert.NoError(t, err)
test.Value = "2"
_, err = testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Cols("Value").Update(test)
assert.NoError(t, err)
test.Value = "3"
num, err := testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Update(test)
assert.NoError(t, err)
assert.EqualValues(t, 1, num)
test.Value = "4"
_, err = testEngine.ID([]interface{}{test.Id, test.Name}).Update(test)
assert.NoError(t, err)
type MySlice []interface{}
test.Value = "5"
_, err = testEngine.ID(&MySlice{test.Id, test.Name}).Update(test)
assert.NoError(t, err)
}