refactor some code #2000
|
@ -18,9 +18,15 @@ type ScanContext struct {
|
|||
UserLocation *time.Location
|
||||
}
|
||||
|
||||
// DriverFeatures represents driver feature
|
||||
type DriverFeatures struct {
|
||||
SupportReturnInsertedID bool
|
||||
}
|
||||
|
||||
// Driver represents a database driver
|
||||
type Driver interface {
|
||||
Parse(string, string) (*URI, error)
|
||||
Features() *DriverFeatures
|
||||
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
|
||||
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
|
||||
}
|
||||
|
|
|
@ -653,6 +653,12 @@ type odbcDriver struct {
|
|||
baseDriver
|
||||
}
|
||||
|
||||
func (p *odbcDriver) Features() *DriverFeatures {
|
||||
return &DriverFeatures{
|
||||
SupportReturnInsertedID: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||
var dbName string
|
||||
|
||||
|
|
|
@ -674,6 +674,12 @@ type mysqlDriver struct {
|
|||
baseDriver
|
||||
}
|
||||
|
||||
func (p *mysqlDriver) Features() *DriverFeatures {
|
||||
return &DriverFeatures{
|
||||
SupportReturnInsertedID: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||
dsnPattern := regexp.MustCompile(
|
||||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
||||
|
|
|
@ -854,6 +854,12 @@ type godrorDriver struct {
|
|||
baseDriver
|
||||
}
|
||||
|
||||
func (g *godrorDriver) Features() *DriverFeatures {
|
||||
return &DriverFeatures{
|
||||
SupportReturnInsertedID: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||
db := &URI{DBType: schemas.ORACLE}
|
||||
dsnPattern := regexp.MustCompile(
|
||||
|
|
|
@ -1387,6 +1387,12 @@ func parseOpts(name string, o values) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *pqDriver) Features() *DriverFeatures {
|
||||
return &DriverFeatures{
|
||||
SupportReturnInsertedID: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||
db := &URI{DBType: schemas.POSTGRES}
|
||||
var err error
|
||||
|
|
|
@ -565,6 +565,12 @@ type sqlite3Driver struct {
|
|||
baseDriver
|
||||
}
|
||||
|
||||
func (p *sqlite3Driver) Features() *DriverFeatures {
|
||||
return &DriverFeatures{
|
||||
SupportReturnInsertedID: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||
if strings.Contains(dataSourceName, "?") {
|
||||
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -334,13 +333,18 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
|
||||
}
|
||||
|
||||
// for postgres, many of them didn't implement lastInsertId, so we should
|
||||
// implemented it ourself.
|
||||
if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
|
||||
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
|
||||
// if there is auto increment column and driver don't support return it
|
||||
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
|
||||
var sql = sqlStr
|
||||
if session.engine.dialect.URI().DBType == schemas.ORACLE {
|
||||
sql = "select seq_atable.currval from dual"
|
||||
}
|
||||
|
||||
rows, err := session.queryRows(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
defer handleAfterInsertProcessorFunc(bean)
|
||||
|
||||
|
@ -355,56 +359,16 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if len(res) < 1 {
|
||||
return 0, errors.New("insert no error but not returned id")
|
||||
}
|
||||
|
||||
idByte := res[0][table.AutoIncrement]
|
||||
id, err := strconv.ParseInt(string(idByte), 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return 1, convertAssignV(aiValue.Addr(), id)
|
||||
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
|
||||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
|
||||
res, err := session.queryBytes(sqlStr, args...)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer handleAfterInsertProcessorFunc(bean)
|
||||
|
||||
session.cacheInsert(tableName)
|
||||
|
||||
if table.Version != "" && session.statement.CheckVersion {
|
||||
verValue, err := table.VersionColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
} else if verValue.IsValid() && verValue.CanSet() {
|
||||
session.incrVersionFieldValue(verValue)
|
||||
var id int64
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
return 0, rows.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if len(res) < 1 {
|
||||
return 0, errors.New("insert successfully but not returned id")
|
||||
}
|
||||
|
||||
idByte := res[0][table.AutoIncrement]
|
||||
id, err := strconv.ParseInt(string(idByte), 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
|
|
|
@ -242,6 +242,10 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
|
|||
}
|
||||
|
||||
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
||||
if isNotTitle(field.Name) {
|
||||
return nil, ErrIgnoreField
|
||||
}
|
||||
|
||||
var (
|
||||
tag = field.Tag
|
||||
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
|
||||
|
@ -282,12 +286,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
|||
table.Name = names.GetTableName(parser.tableMapper, v)
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
var field = t.Field(i)
|
||||
if isNotTitle(field.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
col, err := parser.parseField(table, i, field, v.Field(i))
|
||||
col, err := parser.parseField(table, i, t.Field(i), v.Field(i))
|
||||
if err == ErrIgnoreField {
|
||||
continue
|
||||
} else if err != nil {
|
||||
|
|
12
tags/tag.go
12
tags/tag.go
|
@ -101,11 +101,12 @@ type Handler func(ctx *Context) error
|
|||
var (
|
||||
// defaultTagHandlers enumerates all the default tag handler
|
||||
defaultTagHandlers = map[string]Handler{
|
||||
"-": IgnoreHandler,
|
||||
"<-": OnlyFromDBTagHandler,
|
||||
"->": OnlyToDBTagHandler,
|
||||
"PK": PKTagHandler,
|
||||
"NULL": NULLTagHandler,
|
||||
"NOT": IgnoreTagHandler,
|
||||
"NOT": NotTagHandler,
|
||||
"AUTOINCR": AutoIncrTagHandler,
|
||||
"DEFAULT": DefaultTagHandler,
|
||||
"CREATED": CreatedTagHandler,
|
||||
|
@ -130,11 +131,16 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// IgnoreTagHandler describes ignored tag handler
|
||||
func IgnoreTagHandler(ctx *Context) error {
|
||||
// NotTagHandler describes ignored tag handler
|
||||
func NotTagHandler(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreHandler represetns the field should be ignored
|
||||
func IgnoreHandler(ctx *Context) error {
|
||||
return ErrIgnoreField
|
||||
}
|
||||
|
||||
// OnlyFromDBTagHandler describes mapping direction tag handler
|
||||
func OnlyFromDBTagHandler(ctx *Context) error {
|
||||
ctx.col.MapType = schemas.ONLYFROMDB
|
||||
|
|
Loading…
Reference in New Issue
Block a user