Improve quote policy #1567

Merged
lunny merged 8 commits from lunny/quote_policy2 into master 2020-03-06 07:48:36 +00:00
2 changed files with 34 additions and 18 deletions
Showing only changes of commit 04e6027f52 - Show all commits

View File

@ -29,13 +29,10 @@ func (s *QuoteFilter) Do(sql string) string {
var buf strings.Builder var buf strings.Builder
buf.Grow(len(sql)) buf.Grow(len(sql))
var inSingleQuote bool var beginSingleQuote bool
var prefix = true var prefix = true
for i := 0; i < len(sql); i++ { for i := 0; i < len(sql); i++ {
if sql[i] == '\'' && (i == 0 || sql[i-1] != '\\') { if !beginSingleQuote && sql[i] == '`' {
inSingleQuote = !inSingleQuote
}
if !inSingleQuote && sql[i] == '`' {
if prefix { if prefix {
buf.WriteByte(s.quoter.Prefix) buf.WriteByte(s.quoter.Prefix)
} else { } else {
@ -43,6 +40,9 @@ func (s *QuoteFilter) Do(sql string) string {
} }
prefix = !prefix prefix = !prefix
} else { } else {
if sql[i] == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteByte(sql[i]) buf.WriteByte(sql[i])
} }
} }
@ -57,23 +57,22 @@ type SeqFilter struct {
} }
func convertQuestionMark(sql, prefix string, start int) string { func convertQuestionMark(sql, prefix string, start int) string {
var ( var buf strings.Builder
buf strings.Builder var beginSingleQuote bool
beginSingleQuote bool var index = start
index = start for _, c := range sql {
)
for i, c := range sql {
if !beginSingleQuote && c == '?' { if !beginSingleQuote && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++ index++
} else { } else {
if c == '\'' && (i > 0 && sql[i-1] != '\\') { if c == '\'' {
beginSingleQuote = !beginSingleQuote beginSingleQuote = !beginSingleQuote
} }
buf.WriteRune(c) buf.WriteRune(c)
} }
} }
return buf.String() return buf.String()
} }
func (s *SeqFilter) Do(sql string) string { func (s *SeqFilter) Do(sql string) string {

View File

@ -10,12 +10,29 @@ import (
func TestQuoteFilter_Do(t *testing.T) { func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReverse}} f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReverse}}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" var kases = []struct {
res := f.Do(sql) source string
assert.EqualValues(t, expected string
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", }{
res, {
) "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
},
{
"SELECT 'abc```test```''', `a` FROM b",
"SELECT 'abc```test```''', [a] FROM b",
},
{
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
},
}
for _, kase := range kases {
t.Run(kase.source, func(t *testing.T) {
assert.EqualValues(t, kase.expected, f.Do(kase.source))
})
}
} }
func TestSeqFilter(t *testing.T) { func TestSeqFilter(t *testing.T) {