From b10535959b19ca9267a0e9772f717515149e62e5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Sep 2019 01:06:53 +0800 Subject: [PATCH 1/3] fix bug on insert where --- session_insert_test.go | 8 +++++++ statement_args.go | 53 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/session_insert_test.go b/session_insert_test.go index 88879ef6..b25cc6e1 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -908,6 +908,14 @@ func TestInsertWhere(t *testing.T) { assert.True(t, has) assert.EqualValues(t, "trest3", j3.Name) assert.EqualValues(t, 3, j3.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]string{ + "name": "10';delete * from insert_where; --", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) } type NightlyRate struct { diff --git a/statement_args.go b/statement_args.go index 4ce336f4..0129d8e8 100644 --- a/statement_args.go +++ b/statement_args.go @@ -6,17 +6,60 @@ package xorm import ( "fmt" + "reflect" + "strings" + "time" "xorm.io/builder" "xorm.io/core" ) +func quoteNeeded(a interface{}) bool { + switch a.(type) { + case int, int8, int16, int32, int64: + return false + case uint, uint8, uint16, uint32, uint64: + return false + case float32, float64: + return false + case bool: + return false + case string: + return true + case time.Time, *time.Time: + return true + case builder.Builder, *builder.Builder: + return false + } + + t := reflect.TypeOf(a) + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return false + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return false + case reflect.Float32, reflect.Float64: + return false + case reflect.Bool: + return false + case reflect.String: + return true + } + + return true +} + +func covertArg(arg interface{}) string { + if quoteNeeded(arg) { + argv := fmt.Sprintf("%v", arg) + return "'" + strings.Replace(argv, "'", "''", -1) + "'" + } + + return fmt.Sprintf("%v", arg) +} + func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { - case string: - if _, err := w.WriteString("'" + argv + "'"); err != nil { - return err - } case bool: if statement.Engine.dialect.DBType() == core.MSSQL { if argv { @@ -50,7 +93,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return err } default: - if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil { + if _, err := w.WriteString(covertArg(arg)); err != nil { return err } } -- 2.40.1 From 93fc4604d0190786335a0d3d8a819756c9e66f53 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Sep 2019 01:15:14 +0800 Subject: [PATCH 2/3] fix bug --- session_insert_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/session_insert_test.go b/session_insert_test.go index b25cc6e1..d040c9e9 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -911,8 +911,9 @@ func TestInsertWhere(t *testing.T) { inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). - Insert(map[string]string{ - "name": "10';delete * from insert_where; --", + Insert(map[string]interface{}{ + "repo_id": 1, + "name": "10';delete * from insert_where; --", }) assert.NoError(t, err) assert.EqualValues(t, 1, inserted) -- 2.40.1 From 350fbb4ca242ca607e3d8e4b0d21b008c3af719f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Sep 2019 01:21:09 +0800 Subject: [PATCH 3/3] fix lint --- statement_args.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/statement_args.go b/statement_args.go index 0129d8e8..23496443 100644 --- a/statement_args.go +++ b/statement_args.go @@ -49,7 +49,7 @@ func quoteNeeded(a interface{}) bool { return true } -func covertArg(arg interface{}) string { +func convertArg(arg interface{}) string { if quoteNeeded(arg) { argv := fmt.Sprintf("%v", arg) return "'" + strings.Replace(argv, "'", "''", -1) + "'" @@ -93,7 +93,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return err } default: - if _, err := w.WriteString(covertArg(arg)); err != nil { + if _, err := w.WriteString(convertArg(arg)); err != nil { return err } } -- 2.40.1