Improve quote policy #1567

Merged
lunny merged 8 commits from lunny/quote_policy2 into master 2020-03-06 07:48:36 +00:00
2 changed files with 34 additions and 18 deletions
Showing only changes of commit 04e6027f52 - Show all commits

View File

@ -29,13 +29,10 @@ func (s *QuoteFilter) Do(sql string) string {
var buf strings.Builder
buf.Grow(len(sql))
var inSingleQuote bool
var beginSingleQuote bool
var prefix = true
for i := 0; i < len(sql); i++ {
if sql[i] == '\'' && (i == 0 || sql[i-1] != '\\') {
inSingleQuote = !inSingleQuote
}
if !inSingleQuote && sql[i] == '`' {
if !beginSingleQuote && sql[i] == '`' {
if prefix {
buf.WriteByte(s.quoter.Prefix)
} else {
@ -43,6 +40,9 @@ func (s *QuoteFilter) Do(sql string) string {
}
prefix = !prefix
} else {
if sql[i] == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteByte(sql[i])
}
}
@ -57,23 +57,22 @@ type SeqFilter struct {
}
func convertQuestionMark(sql, prefix string, start int) string {
var (
buf strings.Builder
beginSingleQuote bool
index = start
)
for i, c := range sql {
var buf strings.Builder
var beginSingleQuote bool
var index = start
for _, c := range sql {
if !beginSingleQuote && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++
} else {
if c == '\'' && (i > 0 && sql[i-1] != '\\') {
if c == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteRune(c)
}
}
return buf.String()
}
func (s *SeqFilter) Do(sql string) string {

View File

@ -10,12 +10,29 @@ import (
func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReverse}}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
res := f.Do(sql)
assert.EqualValues(t,
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
res,
)
var kases = []struct {
source string
expected string
}{
{
"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
},
{
"SELECT 'abc```test```''', `a` FROM b",
"SELECT 'abc```test```''', [a] FROM b",
},
{
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
},
}
for _, kase := range kases {
t.Run(kase.source, func(t *testing.T) {
assert.EqualValues(t, kase.expected, f.Do(kase.source))
})
}
}
func TestSeqFilter(t *testing.T) {