Fix pk bug #1602
|
@ -0,0 +1,79 @@
|
|||
// Copyright 2017 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package statements
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
var (
|
||||
ptrPkType = reflect.TypeOf(&schemas.PK{})
|
||||
pkType = reflect.TypeOf(schemas.PK{})
|
||||
stringType = reflect.TypeOf("")
|
||||
intType = reflect.TypeOf(int64(0))
|
||||
uintType = reflect.TypeOf(uint64(0))
|
||||
)
|
||||
|
||||
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
|
||||
func (statement *Statement) ID(id interface{}) *Statement {
|
||||
switch t := id.(type) {
|
||||
case *schemas.PK:
|
||||
statement.idParam = *t
|
||||
case schemas.PK:
|
||||
statement.idParam = t
|
||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
statement.idParam = schemas.PK{id}
|
||||
default:
|
||||
idValue := reflect.ValueOf(id)
|
||||
idType := idValue.Type()
|
||||
|
||||
switch idType.Kind() {
|
||||
case reflect.String:
|
||||
statement.idParam = schemas.PK{idValue.Convert(stringType).Interface()}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
statement.idParam = schemas.PK{idValue.Convert(intType).Interface()}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
statement.idParam = schemas.PK{idValue.Convert(uintType).Interface()}
|
||||
case reflect.Slice:
|
||||
if idType.ConvertibleTo(pkType) {
|
||||
statement.idParam = idValue.Convert(pkType).Interface().(schemas.PK)
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if idType.ConvertibleTo(ptrPkType) {
|
||||
statement.idParam = idValue.Convert(ptrPkType).Elem().Interface().(schemas.PK)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if statement.idParam == nil {
|
||||
statement.LastError = fmt.Errorf("ID param %#v is not supported", id)
|
||||
}
|
||||
|
||||
return statement
|
||||
}
|
||||
|
||||
func (statement *Statement) ProcessIDParam() error {
|
||||
if statement.idParam == nil || statement.RefTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) {
|
||||
fmt.Println("=====", statement.RefTable.PrimaryKeys, statement.idParam)
|
||||
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
|
||||
len(statement.RefTable.PrimaryKeys),
|
||||
len(statement.idParam),
|
||||
)
|
||||
}
|
||||
|
||||
for i, col := range statement.RefTable.PKColumns() {
|
||||
var colName = statement.colName(col, statement.TableName())
|
||||
statement.cond = statement.cond.And(builder.Eq{colName: statement.idParam[i]})
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -41,7 +41,7 @@ type Statement struct {
|
|||
tagParser *tags.Parser
|
||||
Start int
|
||||
LimitN *int
|
||||
idParam *schemas.PK
|
||||
idParam schemas.PK
|
||||
OrderStr string
|
||||
JoinStr string
|
||||
joinArgs []interface{}
|
||||
|
@ -319,34 +319,6 @@ func (statement *Statement) TableName() string {
|
|||
return statement.tableName
|
||||
}
|
||||
|
||||
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
|
||||
func (statement *Statement) ID(id interface{}) *Statement {
|
||||
idValue := reflect.ValueOf(id)
|
||||
idType := reflect.TypeOf(idValue.Interface())
|
||||
|
||||
switch idType {
|
||||
case ptrPkType:
|
||||
if pkPtr, ok := (id).(*schemas.PK); ok {
|
||||
statement.idParam = pkPtr
|
||||
return statement
|
||||
}
|
||||
case pkType:
|
||||
if pk, ok := (id).(schemas.PK); ok {
|
||||
statement.idParam = &pk
|
||||
return statement
|
||||
}
|
||||
}
|
||||
|
||||
switch idType.Kind() {
|
||||
case reflect.String:
|
||||
statement.idParam = &schemas.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
|
||||
return statement
|
||||
}
|
||||
|
||||
statement.idParam = &schemas.PK{id}
|
||||
return statement
|
||||
}
|
||||
|
||||
// Incr Generate "Update ... Set column = column + arg" statement
|
||||
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
|
||||
if len(arg) > 0 {
|
||||
|
@ -981,25 +953,6 @@ func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
|||
return "", nil, ErrUnSupportedType
|
||||
}
|
||||
|
||||
func (statement *Statement) ProcessIDParam() error {
|
||||
if statement.idParam == nil || statement.RefTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
|
||||
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
|
||||
len(statement.RefTable.PrimaryKeys),
|
||||
len(*statement.idParam),
|
||||
)
|
||||
}
|
||||
|
||||
for i, col := range statement.RefTable.PKColumns() {
|
||||
var colName = statement.colName(col, statement.TableName())
|
||||
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
|
||||
var colnames = make([]string, len(cols))
|
||||
for i, col := range cols {
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
// Copyright 2017 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package statements
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
var (
|
||||
ptrPkType = reflect.TypeOf(&schemas.PK{})
|
||||
pkType = reflect.TypeOf(schemas.PK{})
|
||||
)
|
|
@ -18,58 +18,73 @@ import (
|
|||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil,
|
||||
includeAutoIncr, update bool) (bool, error) {
|
||||
columnMap := statement.ColumnMap
|
||||
omitColumnMap := statement.OmitColumnMap
|
||||
unscoped := statement.unscoped
|
||||
|
||||
if !includeVersion && col.IsVersion {
|
||||
return false, nil
|
||||
}
|
||||
if col.IsCreated && !columnMap.Contain(col.Name) {
|
||||
return false, nil
|
||||
}
|
||||
if !includeUpdated && col.IsUpdated {
|
||||
return false, nil
|
||||
}
|
||||
if !includeAutoIncr && col.IsAutoIncrement {
|
||||
return false, nil
|
||||
}
|
||||
if col.IsDeleted && !unscoped {
|
||||
return false, nil
|
||||
}
|
||||
if omitColumnMap.Contain(col.Name) {
|
||||
return false, nil
|
||||
}
|
||||
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if col.MapType == schemas.ONLYFROMDB {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if statement.IncrColumns.IsColExist(col.Name) {
|
||||
return false, nil
|
||||
} else if statement.DecrColumns.IsColExist(col.Name) {
|
||||
return false, nil
|
||||
} else if statement.ExprColumns.IsColExist(col.Name) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// BuildUpdates auto generating update columnes and values according a struct
|
||||
func (statement *Statement) BuildUpdates(bean interface{},
|
||||
func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
||||
includeVersion, includeUpdated, includeNil,
|
||||
includeAutoIncr, update bool) ([]string, []interface{}, error) {
|
||||
//engine := statement.Engine
|
||||
table := statement.RefTable
|
||||
allUseBool := statement.allUseBool
|
||||
useAllCols := statement.useAllCols
|
||||
mustColumnMap := statement.MustColumnMap
|
||||
nullableMap := statement.NullableMap
|
||||
columnMap := statement.ColumnMap
|
||||
omitColumnMap := statement.OmitColumnMap
|
||||
unscoped := statement.unscoped
|
||||
|
||||
var colNames = make([]string, 0)
|
||||
var args = make([]interface{}, 0)
|
||||
|
||||
for _, col := range table.Columns() {
|
||||
if !includeVersion && col.IsVersion {
|
||||
continue
|
||||
ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil,
|
||||
includeAutoIncr, update)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if col.IsCreated && !columnMap.Contain(col.Name) {
|
||||
continue
|
||||
}
|
||||
if !includeUpdated && col.IsUpdated {
|
||||
continue
|
||||
}
|
||||
if !includeAutoIncr && col.IsAutoIncrement {
|
||||
continue
|
||||
}
|
||||
if col.IsDeleted && !unscoped {
|
||||
continue
|
||||
}
|
||||
if omitColumnMap.Contain(col.Name) {
|
||||
continue
|
||||
}
|
||||
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if col.MapType == schemas.ONLYFROMDB {
|
||||
continue
|
||||
}
|
||||
|
||||
if statement.IncrColumns.IsColExist(col.Name) {
|
||||
continue
|
||||
} else if statement.DecrColumns.IsColExist(col.Name) {
|
||||
continue
|
||||
} else if statement.ExprColumns.IsColExist(col.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldValuePtr, err := col.ValueOf(bean)
|
||||
fieldValuePtr, err := col.ValueOfV(&tableValue)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -273,9 +288,6 @@ func (statement *Statement) BuildUpdates(bean interface{},
|
|||
|
||||
APPEND:
|
||||
args = append(args, val)
|
||||
if col.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ const (
|
|||
type Column struct {
|
||||
Name string
|
||||
TableName string
|
||||
FieldName string
|
||||
FieldName string // Avaiable only when parsed from a struct
|
||||
SQLType SQLType
|
||||
IsJSON bool
|
||||
Length int
|
||||
|
|
|
@ -53,13 +53,9 @@ func (table *Table) ColumnsSeq() []string {
|
|||
}
|
||||
|
||||
func (table *Table) columnsByName(name string) []*Column {
|
||||
n := len(name)
|
||||
for k := range table.columnsMap {
|
||||
if len(k) != n {
|
||||
continue
|
||||
}
|
||||
for k, cols := range table.columnsMap {
|
||||
if strings.EqualFold(k, name) {
|
||||
return table.columnsMap[k]
|
||||
return cols
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -177,7 +177,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
}
|
||||
|
||||
if session.statement.ColumnStr() == "" {
|
||||
colNames, args, err = session.statement.BuildUpdates(bean, false, false,
|
||||
colNames, args, err = session.statement.BuildUpdates(v, false, false,
|
||||
false, false, true)
|
||||
} else {
|
||||
colNames, args, err = session.genUpdateColumns(bean)
|
||||
|
|
|
@ -1303,3 +1303,40 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assertGetRecord()
|
||||
}
|
||||
|
||||
func TestUpdateMultiplePK(t *testing.T) {
|
||||
type TestUpdateMultiplePKStruct struct {
|
||||
Id string `xorm:"notnull pk" description:"唯一ID号"`
|
||||
Name string `xorm:"notnull pk" description:"名称"`
|
||||
Value string `xorm:"notnull varchar(4000)" description:"值"`
|
||||
}
|
||||
|
||||
assert.NoError(t, prepareEngine())
|
||||
assertSync(t, new(TestUpdateMultiplePKStruct))
|
||||
|
||||
test := &TestUpdateMultiplePKStruct{
|
||||
Id: "ID1",
|
||||
Name: "Name1",
|
||||
Value: "1",
|
||||
}
|
||||
_, err := testEngine.Insert(test)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.Value = "2"
|
||||
_, err = testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Cols("Value").Update(test)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.Value = "3"
|
||||
num, err := testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Update(test)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, num)
|
||||
|
||||
test.Value = "4"
|
||||
_, err = testEngine.ID([]interface{}{test.Id, test.Name}).Update(test)
|
||||
assert.NoError(t, err)
|
||||
|
||||
type MySlice []interface{}
|
||||
test.Value = "5"
|
||||
_, err = testEngine.ID(&MySlice{test.Id, test.Name}).Update(test)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user