Add pgx driver support #1795

Merged
lunny merged 7 commits from lunny/pgx_support into master 2021-08-05 06:04:11 +00:00
3 changed files with 78 additions and 88 deletions
Showing only changes of commit de95ea6bb0 - Show all commits

View File

@ -349,7 +349,7 @@ func TestUpdate1(t *testing.T) {
And("height = ?", user.Height).
And("departname = ?", "").
And("detail_id = ?", 0).
And("is_man = ?", 0).
And("is_man = ?", false).
Get(&Userinfo{})
assert.NoError(t, err)
assert.True(t, has, "cannot insert properly")
@ -825,7 +825,7 @@ func TestNewUpdate(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 0, af)
af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", 13126564922).Update(&changeUsr)
af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", "13126564922").Update(&changeUsr)
assert.NoError(t, err)
assert.EqualValues(t, 0, af)
}

View File

@ -5,9 +5,9 @@
package schemas
import (
"database/sql"
"math/big"
"reflect"
"sort"
"strings"
"time"
)
@ -229,88 +229,40 @@ var (
Array: ARRAY_TYPE,
}
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison
var (
emptyString string
boolDefault bool
byteDefault byte
complex64Default complex64
complex128Default complex128
float32Default float32
float64Default float64
int64Default int64
uint64Default uint64
int32Default int32
uint32Default uint32
int16Default int16
uint16Default uint16
int8Default int8
uint8Default uint8
intDefault int
uintDefault uint
timeDefault time.Time
bigFloatDefault big.Float
)
// enumerates all types
var (
IntType = reflect.TypeOf(intDefault)
Int8Type = reflect.TypeOf(int8Default)
Int16Type = reflect.TypeOf(int16Default)
Int32Type = reflect.TypeOf(int32Default)
Int64Type = reflect.TypeOf(int64Default)
IntType = reflect.TypeOf((*int)(nil)).Elem()
Int8Type = reflect.TypeOf((*int8)(nil)).Elem()
Int16Type = reflect.TypeOf((*int16)(nil)).Elem()
Int32Type = reflect.TypeOf((*int32)(nil)).Elem()
Int64Type = reflect.TypeOf((*int64)(nil)).Elem()
UintType = reflect.TypeOf(uintDefault)
Uint8Type = reflect.TypeOf(uint8Default)
Uint16Type = reflect.TypeOf(uint16Default)
Uint32Type = reflect.TypeOf(uint32Default)
Uint64Type = reflect.TypeOf(uint64Default)
UintType = reflect.TypeOf((*uint)(nil)).Elem()
Uint8Type = reflect.TypeOf((*uint8)(nil)).Elem()
Uint16Type = reflect.TypeOf((*uint16)(nil)).Elem()
Uint32Type = reflect.TypeOf((*uint32)(nil)).Elem()
Uint64Type = reflect.TypeOf((*uint64)(nil)).Elem()
Float32Type = reflect.TypeOf(float32Default)
Float64Type = reflect.TypeOf(float64Default)
Float32Type = reflect.TypeOf((*float32)(nil)).Elem()
Float64Type = reflect.TypeOf((*float64)(nil)).Elem()
Complex64Type = reflect.TypeOf(complex64Default)
Complex128Type = reflect.TypeOf(complex128Default)
Complex64Type = reflect.TypeOf((*complex64)(nil)).Elem()
Complex128Type = reflect.TypeOf((*complex128)(nil)).Elem()
StringType = reflect.TypeOf(emptyString)
BoolType = reflect.TypeOf(boolDefault)
ByteType = reflect.TypeOf(byteDefault)
StringType = reflect.TypeOf((*string)(nil)).Elem()
BoolType = reflect.TypeOf((*bool)(nil)).Elem()
ByteType = reflect.TypeOf((*byte)(nil)).Elem()
BytesType = reflect.SliceOf(ByteType)
TimeType = reflect.TypeOf(timeDefault)
BigFloatType = reflect.TypeOf(bigFloatDefault)
)
// enumerates all types
var (
PtrIntType = reflect.PtrTo(IntType)
PtrInt8Type = reflect.PtrTo(Int8Type)
PtrInt16Type = reflect.PtrTo(Int16Type)
PtrInt32Type = reflect.PtrTo(Int32Type)
PtrInt64Type = reflect.PtrTo(Int64Type)
PtrUintType = reflect.PtrTo(UintType)
PtrUint8Type = reflect.PtrTo(Uint8Type)
PtrUint16Type = reflect.PtrTo(Uint16Type)
PtrUint32Type = reflect.PtrTo(Uint32Type)
PtrUint64Type = reflect.PtrTo(Uint64Type)
PtrFloat32Type = reflect.PtrTo(Float32Type)
PtrFloat64Type = reflect.PtrTo(Float64Type)
PtrComplex64Type = reflect.PtrTo(Complex64Type)
PtrComplex128Type = reflect.PtrTo(Complex128Type)
PtrStringType = reflect.PtrTo(StringType)
PtrBoolType = reflect.PtrTo(BoolType)
PtrByteType = reflect.PtrTo(ByteType)
PtrTimeType = reflect.PtrTo(TimeType)
TimeType = reflect.TypeOf((*time.Time)(nil)).Elem()
BigFloatType = reflect.TypeOf((*big.Float)(nil)).Elem()
NullFloat64Type = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
NullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
NullInt32Type = reflect.TypeOf((*sql.NullInt32)(nil)).Elem()
NullInt64Type = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
NullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
)
// Type2SQLType generate SQLType acorrding Go's type
@ -331,7 +283,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(byteDefault) {
if t.Elem() == ByteType {
st = SQLType{Blob, 0, 0}
} else {
st = SQLType{Text, 0, 0}
@ -343,6 +295,16 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
case reflect.Struct:
if t.ConvertibleTo(TimeType) {
st = SQLType{DateTime, 0, 0}
} else if t.ConvertibleTo(NullFloat64Type) {
st = SQLType{Double, 0, 0}
} else if t.ConvertibleTo(NullStringType) {
st = SQLType{Varchar, 255, 0}
} else if t.ConvertibleTo(NullInt32Type) {
st = SQLType{Integer, 0, 0}
} else if t.ConvertibleTo(NullInt64Type) {
st = SQLType{BigInt, 0, 0}
} else if t.ConvertibleTo(NullBoolType) {
st = SQLType{Boolean, 0, 0}
} else {
// TODO need to handle association struct
st = SQLType{Text, 0, 0}
@ -360,25 +322,25 @@ func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name)
switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1)
return IntType
case BigInt, BigSerial:
return reflect.TypeOf(int64(1))
return Int64Type
case Float, Real:
return reflect.TypeOf(float32(1))
return Float32Type
case Double:
return reflect.TypeOf(float64(1))
return Float64Type
case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
return reflect.TypeOf("")
return StringType
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
return reflect.TypeOf([]byte{})
return BytesType
case Bool:
return reflect.TypeOf(true)
return BoolType
case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year:
return reflect.TypeOf(timeDefault)
return TimeType
case Decimal, Numeric, Money, SmallMoney:
return reflect.TypeOf("")
return StringType
default:
return reflect.TypeOf("")
return StringType
}
}

View File

@ -7,6 +7,7 @@ package tags
import (
"encoding/gob"
"errors"
"fmt"
"reflect"
"strings"
"sync"
@ -127,6 +128,25 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
// ErrIgnoreField represents an error to ignore field
var ErrIgnoreField = errors.New("field will be ignored")
func (parser *Parser) getSQLTypeByType(t reflect.Type) (schemas.SQLType, error) {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Struct {
v, ok := parser.tableCache.Load(t)
if ok {
pkCols := v.(*schemas.Table).PKColumns()
if len(pkCols) == 1 {
return pkCols[0].SQLType, nil
}
if len(pkCols) > 1 {
return schemas.SQLType{}, fmt.Errorf("unsupported mulitiple primary key on cascade")
}
}
}
return schemas.Type2SQLType(t), nil
}
func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
var sqlType schemas.SQLType
if fieldValue.CanAddr() {
@ -137,7 +157,11 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
} else {
sqlType = schemas.Type2SQLType(field.Type)
var err error
sqlType, err = parser.getSQLTypeByType(field.Type)
if err != nil {
return nil, err
}
}
col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
field.Name, sqlType, sqlType.DefaultLength,
@ -215,7 +239,11 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
}
if col.SQLType.Name == "" {
col.SQLType = schemas.Type2SQLType(field.Type)
var err error
col.SQLType, err = parser.getSQLTypeByType(field.Type)
if err != nil {
return nil, err
}
}
if ctx.isUnsigned && col.SQLType.IsNumeric() && !strings.HasPrefix(col.SQLType.Name, "UNSIGNED") {
col.SQLType.Name = "UNSIGNED " + col.SQLType.Name