Improve quote policy #1567
|
@ -29,13 +29,10 @@ func (s *QuoteFilter) Do(sql string) string {
|
|||
var buf strings.Builder
|
||||
buf.Grow(len(sql))
|
||||
|
||||
var inSingleQuote bool
|
||||
var beginSingleQuote bool
|
||||
var prefix = true
|
||||
for i := 0; i < len(sql); i++ {
|
||||
if sql[i] == '\'' && (i == 0 || sql[i-1] != '\\') {
|
||||
inSingleQuote = !inSingleQuote
|
||||
}
|
||||
if !inSingleQuote && sql[i] == '`' {
|
||||
if !beginSingleQuote && sql[i] == '`' {
|
||||
if prefix {
|
||||
buf.WriteByte(s.quoter.Prefix)
|
||||
} else {
|
||||
|
@ -43,6 +40,9 @@ func (s *QuoteFilter) Do(sql string) string {
|
|||
}
|
||||
prefix = !prefix
|
||||
} else {
|
||||
if sql[i] == '\'' {
|
||||
beginSingleQuote = !beginSingleQuote
|
||||
}
|
||||
buf.WriteByte(sql[i])
|
||||
}
|
||||
}
|
||||
|
@ -57,23 +57,22 @@ type SeqFilter struct {
|
|||
}
|
||||
|
||||
func convertQuestionMark(sql, prefix string, start int) string {
|
||||
var (
|
||||
buf strings.Builder
|
||||
beginSingleQuote bool
|
||||
index = start
|
||||
)
|
||||
for i, c := range sql {
|
||||
var buf strings.Builder
|
||||
var beginSingleQuote bool
|
||||
var index = start
|
||||
for _, c := range sql {
|
||||
if !beginSingleQuote && c == '?' {
|
||||
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
|
||||
index++
|
||||
} else {
|
||||
if c == '\'' && (i > 0 && sql[i-1] != '\\') {
|
||||
if c == '\'' {
|
||||
beginSingleQuote = !beginSingleQuote
|
||||
}
|
||||
buf.WriteRune(c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
|
||||
}
|
||||
|
||||
func (s *SeqFilter) Do(sql string) string {
|
||||
|
|
|
@ -10,12 +10,29 @@ import (
|
|||
|
||||
func TestQuoteFilter_Do(t *testing.T) {
|
||||
f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReverse}}
|
||||
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
||||
res := f.Do(sql)
|
||||
assert.EqualValues(t,
|
||||
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
|
||||
res,
|
||||
)
|
||||
var kases = []struct {
|
||||
source string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
|
||||
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
|
||||
},
|
||||
{
|
||||
"SELECT 'abc```test```''', `a` FROM b",
|
||||
"SELECT 'abc```test```''', [a] FROM b",
|
||||
},
|
||||
{
|
||||
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
|
||||
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, kase := range kases {
|
||||
t.Run(kase.source, func(t *testing.T) {
|
||||
assert.EqualValues(t, kase.expected, f.Do(kase.source))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqFilter(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user