refactor some code #2000
|
@ -18,9 +18,15 @@ type ScanContext struct {
|
||||||
UserLocation *time.Location
|
UserLocation *time.Location
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DriverFeatures represents driver feature
|
||||||
|
type DriverFeatures struct {
|
||||||
|
SupportReturnInsertedID bool
|
||||||
|
}
|
||||||
|
|
||||||
// Driver represents a database driver
|
// Driver represents a database driver
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
Parse(string, string) (*URI, error)
|
Parse(string, string) (*URI, error)
|
||||||
|
Features() *DriverFeatures
|
||||||
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
|
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
|
||||||
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
|
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -653,6 +653,12 @@ type odbcDriver struct {
|
||||||
baseDriver
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *odbcDriver) Features() *DriverFeatures {
|
||||||
|
return &DriverFeatures{
|
||||||
|
SupportReturnInsertedID: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
var dbName string
|
var dbName string
|
||||||
|
|
||||||
|
|
|
@ -674,6 +674,12 @@ type mysqlDriver struct {
|
||||||
baseDriver
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *mysqlDriver) Features() *DriverFeatures {
|
||||||
|
return &DriverFeatures{
|
||||||
|
SupportReturnInsertedID: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
dsnPattern := regexp.MustCompile(
|
dsnPattern := regexp.MustCompile(
|
||||||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
||||||
|
|
|
@ -854,6 +854,12 @@ type godrorDriver struct {
|
||||||
baseDriver
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *godrorDriver) Features() *DriverFeatures {
|
||||||
|
return &DriverFeatures{
|
||||||
|
SupportReturnInsertedID: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
db := &URI{DBType: schemas.ORACLE}
|
db := &URI{DBType: schemas.ORACLE}
|
||||||
dsnPattern := regexp.MustCompile(
|
dsnPattern := regexp.MustCompile(
|
||||||
|
|
|
@ -1387,6 +1387,12 @@ func parseOpts(name string, o values) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *pqDriver) Features() *DriverFeatures {
|
||||||
|
return &DriverFeatures{
|
||||||
|
SupportReturnInsertedID: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
db := &URI{DBType: schemas.POSTGRES}
|
db := &URI{DBType: schemas.POSTGRES}
|
||||||
var err error
|
var err error
|
||||||
|
|
|
@ -565,6 +565,12 @@ type sqlite3Driver struct {
|
||||||
baseDriver
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *sqlite3Driver) Features() *DriverFeatures {
|
||||||
|
return &DriverFeatures{
|
||||||
|
SupportReturnInsertedID: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
if strings.Contains(dataSourceName, "?") {
|
if strings.Contains(dataSourceName, "?") {
|
||||||
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
|
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -334,13 +333,18 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
|
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
|
||||||
}
|
}
|
||||||
|
|
||||||
// for postgres, many of them didn't implement lastInsertId, so we should
|
// if there is auto increment column and driver don't support return it
|
||||||
// implemented it ourself.
|
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
|
||||||
if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
|
var sql = sqlStr
|
||||||
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
|
if session.engine.dialect.URI().DBType == schemas.ORACLE {
|
||||||
|
sql = "select seq_atable.currval from dual"
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := session.queryRows(sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
defer handleAfterInsertProcessorFunc(bean)
|
defer handleAfterInsertProcessorFunc(bean)
|
||||||
|
|
||||||
|
@ -355,56 +359,16 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res) < 1 {
|
var id int64
|
||||||
return 0, errors.New("insert no error but not returned id")
|
if !rows.Next() {
|
||||||
}
|
if rows.Err() != nil {
|
||||||
|
return 0, rows.Err()
|
||||||
idByte := res[0][table.AutoIncrement]
|
|
||||||
id, err := strconv.ParseInt(string(idByte), 10, 64)
|
|
||||||
if err != nil || id <= 0 {
|
|
||||||
return 1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
|
||||||
if err != nil {
|
|
||||||
session.engine.logger.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 1, convertAssignV(aiValue.Addr(), id)
|
|
||||||
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
|
|
||||||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
|
|
||||||
res, err := session.queryBytes(sqlStr, args...)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer handleAfterInsertProcessorFunc(bean)
|
|
||||||
|
|
||||||
session.cacheInsert(tableName)
|
|
||||||
|
|
||||||
if table.Version != "" && session.statement.CheckVersion {
|
|
||||||
verValue, err := table.VersionColumn().ValueOf(bean)
|
|
||||||
if err != nil {
|
|
||||||
session.engine.logger.Errorf("%v", err)
|
|
||||||
} else if verValue.IsValid() && verValue.CanSet() {
|
|
||||||
session.incrVersionFieldValue(verValue)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(res) < 1 {
|
|
||||||
return 0, errors.New("insert successfully but not returned id")
|
return 0, errors.New("insert successfully but not returned id")
|
||||||
}
|
}
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
idByte := res[0][table.AutoIncrement]
|
|
||||||
id, err := strconv.ParseInt(string(idByte), 10, 64)
|
|
||||||
if err != nil || id <= 0 {
|
|
||||||
return 1, err
|
return 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
session.engine.logger.Errorf("%v", err)
|
session.engine.logger.Errorf("%v", err)
|
||||||
|
|
|
@ -242,6 +242,10 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
|
||||||
}
|
}
|
||||||
|
|
||||||
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
||||||
|
if isNotTitle(field.Name) {
|
||||||
|
return nil, ErrIgnoreField
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tag = field.Tag
|
tag = field.Tag
|
||||||
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
|
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
|
||||||
|
@ -282,12 +286,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
||||||
table.Name = names.GetTableName(parser.tableMapper, v)
|
table.Name = names.GetTableName(parser.tableMapper, v)
|
||||||
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
for i := 0; i < t.NumField(); i++ {
|
||||||
var field = t.Field(i)
|
col, err := parser.parseField(table, i, t.Field(i), v.Field(i))
|
||||||
if isNotTitle(field.Name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
col, err := parser.parseField(table, i, field, v.Field(i))
|
|
||||||
if err == ErrIgnoreField {
|
if err == ErrIgnoreField {
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
12
tags/tag.go
12
tags/tag.go
|
@ -101,11 +101,12 @@ type Handler func(ctx *Context) error
|
||||||
var (
|
var (
|
||||||
// defaultTagHandlers enumerates all the default tag handler
|
// defaultTagHandlers enumerates all the default tag handler
|
||||||
defaultTagHandlers = map[string]Handler{
|
defaultTagHandlers = map[string]Handler{
|
||||||
|
"-": IgnoreHandler,
|
||||||
"<-": OnlyFromDBTagHandler,
|
"<-": OnlyFromDBTagHandler,
|
||||||
"->": OnlyToDBTagHandler,
|
"->": OnlyToDBTagHandler,
|
||||||
"PK": PKTagHandler,
|
"PK": PKTagHandler,
|
||||||
"NULL": NULLTagHandler,
|
"NULL": NULLTagHandler,
|
||||||
"NOT": IgnoreTagHandler,
|
"NOT": NotTagHandler,
|
||||||
"AUTOINCR": AutoIncrTagHandler,
|
"AUTOINCR": AutoIncrTagHandler,
|
||||||
"DEFAULT": DefaultTagHandler,
|
"DEFAULT": DefaultTagHandler,
|
||||||
"CREATED": CreatedTagHandler,
|
"CREATED": CreatedTagHandler,
|
||||||
|
@ -130,11 +131,16 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IgnoreTagHandler describes ignored tag handler
|
// NotTagHandler describes ignored tag handler
|
||||||
func IgnoreTagHandler(ctx *Context) error {
|
func NotTagHandler(ctx *Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IgnoreHandler represetns the field should be ignored
|
||||||
|
func IgnoreHandler(ctx *Context) error {
|
||||||
|
return ErrIgnoreField
|
||||||
|
}
|
||||||
|
|
||||||
// OnlyFromDBTagHandler describes mapping direction tag handler
|
// OnlyFromDBTagHandler describes mapping direction tag handler
|
||||||
func OnlyFromDBTagHandler(ctx *Context) error {
|
func OnlyFromDBTagHandler(ctx *Context) error {
|
||||||
ctx.col.MapType = schemas.ONLYFROMDB
|
ctx.col.MapType = schemas.ONLYFROMDB
|
||||||
|
|
Loading…
Reference in New Issue
Block a user