refactor some code #2000

Merged
lunny merged 2 commits from lunny/refactor_more into master 2021-07-20 16:12:21 +00:00
9 changed files with 64 additions and 59 deletions

View File

@ -18,9 +18,15 @@ type ScanContext struct {
UserLocation *time.Location
}
// DriverFeatures represents driver feature
type DriverFeatures struct {
SupportReturnInsertedID bool
}
// Driver represents a database driver
type Driver interface {
Parse(string, string) (*URI, error)
Features() *DriverFeatures
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
}

View File

@ -653,6 +653,12 @@ type odbcDriver struct {
baseDriver
}
func (p *odbcDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
var dbName string

View File

@ -674,6 +674,12 @@ type mysqlDriver struct {
baseDriver
}
func (p *mysqlDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: true,
}
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]

View File

@ -854,6 +854,12 @@ type godrorDriver struct {
baseDriver
}
func (g *godrorDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(

View File

@ -1387,6 +1387,12 @@ func parseOpts(name string, o values) error {
return nil
}
func (p *pqDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.POSTGRES}
var err error

View File

@ -565,6 +565,12 @@ type sqlite3Driver struct {
baseDriver
}
func (p *sqlite3Driver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: true,
}
}
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]

View File

@ -9,7 +9,6 @@ import (
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
@ -334,13 +333,18 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
}
// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
// if there is auto increment column and driver don't support return it
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
var sql = sqlStr
if session.engine.dialect.URI().DBType == schemas.ORACLE {
sql = "select seq_atable.currval from dual"
}
rows, err := session.queryRows(sql, args...)
if err != nil {
return 0, err
}
defer rows.Close()
defer handleAfterInsertProcessorFunc(bean)
@ -355,56 +359,16 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
}
}
if len(res) < 1 {
return 0, errors.New("insert no error but not returned id")
}
idByte := res[0][table.AutoIncrement]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil || id <= 0 {
return 1, err
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return 1, nil
}
return 1, convertAssignV(aiValue.Addr(), id)
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...)
if err != nil {
return 0, err
}
defer handleAfterInsertProcessorFunc(bean)
session.cacheInsert(tableName)
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
var id int64
if !rows.Next() {
if rows.Err() != nil {
return 0, rows.Err()
}
}
if len(res) < 1 {
return 0, errors.New("insert successfully but not returned id")
}
idByte := res[0][table.AutoIncrement]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil || id <= 0 {
if err := rows.Scan(&id); err != nil {
return 1, err
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)

View File

@ -242,6 +242,10 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
}
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
if isNotTitle(field.Name) {
return nil, ErrIgnoreField
}
var (
tag = field.Tag
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
@ -282,12 +286,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
table.Name = names.GetTableName(parser.tableMapper, v)
for i := 0; i < t.NumField(); i++ {
var field = t.Field(i)
if isNotTitle(field.Name) {
continue
}
col, err := parser.parseField(table, i, field, v.Field(i))
col, err := parser.parseField(table, i, t.Field(i), v.Field(i))
if err == ErrIgnoreField {
continue
} else if err != nil {

View File

@ -101,11 +101,12 @@ type Handler func(ctx *Context) error
var (
// defaultTagHandlers enumerates all the default tag handler
defaultTagHandlers = map[string]Handler{
"-": IgnoreHandler,
"<-": OnlyFromDBTagHandler,
"->": OnlyToDBTagHandler,
"PK": PKTagHandler,
"NULL": NULLTagHandler,
"NOT": IgnoreTagHandler,
"NOT": NotTagHandler,
"AUTOINCR": AutoIncrTagHandler,
"DEFAULT": DefaultTagHandler,
"CREATED": CreatedTagHandler,
@ -130,11 +131,16 @@ func init() {
}
}
// IgnoreTagHandler describes ignored tag handler
func IgnoreTagHandler(ctx *Context) error {
// NotTagHandler describes ignored tag handler
func NotTagHandler(ctx *Context) error {
return nil
}
// IgnoreHandler represetns the field should be ignored
func IgnoreHandler(ctx *Context) error {
return ErrIgnoreField
}
// OnlyFromDBTagHandler describes mapping direction tag handler
func OnlyFromDBTagHandler(ctx *Context) error {
ctx.col.MapType = schemas.ONLYFROMDB