db2 support #1387

Closed
lunny wants to merge 18 commits from lunny/db2_support into master
21 changed files with 816 additions and 118 deletions

View File

@ -363,6 +363,62 @@ services:
commands:
- /cockroach/cockroach start --insecure
---
kind: pipeline
name: test-db2
depends_on:
- test-cockroach
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-db2
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_DB2_HOST: db2
TEST_DB2_PORT: 50000
TEST_DB2_DBNAME: xorm_test
TEST_DB2_USERNAME: sa
TEST_DB2_PASSWORD: xorm_test
DB2HOME: /go/pkg/mod/github.com/ibmdb/clidriver
CGO_CFLAGS: -I$DB2HOME/include
CGO_LDFLAGS: -L$DB2HOME/lib
LD_LIBRARY_PATH: $DB2HOME/lib
commands:
- make test-db2
- TEST_CACHE_ENABLE=true make test-db2
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: db2
pull: default
image: store/ibmcorp/db2_developer_c:11.1.4.4-x86_64
environment:
LICENSE: accept
DB2INSTANCE: db2inst1
DB2INST1_PASSWORD: xorm_test
DBNAME: xorm_test
BLU: false
ENABLE_ORACLE_COMPATIBILITY: false
UPDATEAVAIL: NO
TO_CREATE_SAMPLEDB: false
REPODB: false
IS_OSXFS: true
PERSISTENT_HOME: true
HADR_ENABLED: false
ETCD_ENDPOINT:
ETCD_USERNAME:
ETCD_PASSWORD:
---
kind: pipeline
name: merge_coverage
@ -374,6 +430,7 @@ depends_on:
- test-mssql
- test-tidb
- test-cockroach
- test-db2
trigger:
ref:
- refs/heads/master

View File

@ -18,6 +18,12 @@ TEST_COCKROACH_DBNAME ?= xorm_test
TEST_COCKROACH_USERNAME ?= postgres
TEST_COCKROACH_PASSWORD ?=
TEST_DB2_HOST ?= db2
TEST_DB2_PORT ?= 50000
TEST_DB2_DBNAME ?= gitea
TEST_DB2_USERNAME ?= sa
TEST_DB2_PASSWORD ?= MwantsaSecurePassword1
TEST_MSSQL_HOST ?= mssql:1433
TEST_MSSQL_DBNAME ?= gitea
TEST_MSSQL_USERNAME ?= sa
@ -46,6 +52,10 @@ TEST_TIDB_PASSWORD ?=
TEST_CACHE_ENABLE ?= false
TEST_QUOTE_POLICY ?= always
DB2ORG := $(GOPATH)/pkg/mod/github.com/ibmdb
DB2HOME := $(DB2ORG)/go_ibm_db@v0.4.1/installer
DB2_DRIVER_DIR := $(DB2ORG)/clidriver
.PHONY: all
all: build
@ -146,6 +156,22 @@ test-cockroach\#%: go-check
-conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \
-ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
$(DB2_DRIVER_DIR):
go get -d -v github.com/ibmdb/go_ibm_db
cd $(DB2HOME) && go run setup.go
.PNONY: test-db2
test-db2: go-check $(DB2_DRIVER_DIR)
$(GO) test $(INTEGRATION_PACKAGES) -v -tags=db2 -db=go_ibm_db -cache=$(TEST_CACHE_ENABLE) \
-conn_str="HOSTNAME=$(TEST_DB2_HOST);DATABASE=$(TEST_DB2_DBNAME);PORT=$(TEST_DB2_PORT);UID=$(TEST_DB2_USERNAME);PWD=$(TEST_DB2_PASSWORD)" \
-coverprofile=db2.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-db2\#%
test-db2\#%: go-check
$(GO) test $(INTEGRATION_PACKAGES) -v -run $* -db=go_ibm_db -tags=db2 -cache=$(TEST_CACHE_ENABLE) \
-conn_str="HOSTNAME=$(TEST_DB2_HOST);DATABASE=$(TEST_DB2_DBNAME);PORT=$(TEST_DB2_PORT);UID=$(TEST_DB2_USERNAME);PWD=$(TEST_DB2_PASSWORD)" \
-coverprofile=db2.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-mssql
test-mssql: go-check
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
@ -238,8 +264,7 @@ test-sqlite-schema: go-check
.PHONY: test-sqlite\#%
test-sqlite\#%: go-check
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomicattt
.PNONY: test-tidb
test-tidb: go-check

View File

@ -6,7 +6,7 @@ xorm 是一个简单而强大的Go语言ORM库. 通过它可以使数据库操
[![Build Status](https://drone.gitea.com/api/badges/xorm/xorm/status.svg)](https://drone.gitea.com/xorm/xorm) [![](http://gocover.io/_badge/xorm.io/xorm)](https://gocover.io/xorm.io/xorm) [![](https://goreportcard.com/badge/xorm.io/xorm)](https://goreportcard.com/report/xorm.io/xorm) [![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
## Notice
## 注意
v1.0.0 相对于 v0.8.2 有以下不兼容的变更:

534
dialects/db2.go Normal file
View File

@ -0,0 +1,534 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
var (
db2ReservedWords = map[string]bool{}
db2Quoter = schemas.Quoter{
Prefix: '"',
Suffix: '"',
IsReserved: schemas.AlwaysReserve,
}
)
type db2 struct {
Base
}
func (db *db2) Init(uri *URI) error {
db.quoter = db2Quoter
return db.Base.Init(db, uri)
}
func (db *db2) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) {
rows, err := queryer.QueryContext(ctx, "SELECT service_level, fixpack_num FROM TABLE(sysproc.env_get_inst_info()) as INSTANCEINFO")
if err != nil {
return nil, err
}
defer rows.Close()
if rows.Next() {
var serviceLevel, fixpackNum string
if err := rows.Scan(&serviceLevel, &fixpackNum); err != nil {
return nil, err
}
parts := strings.Split(serviceLevel, " ")
return &schemas.Version{
Number: parts[1],
Level: fixpackNum,
Edition: parts[0],
}, nil
}
return nil, rows.Err()
}
func (db *db2) Features() *DialectFeatures {
return &DialectFeatures{
DefaultClause: "WITH DEFAULT",
}
}
func (db *db2) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATE", "DATETIME", "DATETIME2", "TIME":
return schemas.TIME_TYPE
case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT":
return schemas.TEXT_TYPE
case "FLOAT", "REAL", "BIGINT", "DATETIMEOFFSET", "TINYINT", "SMALLINT", "INT":
return schemas.NUMERIC_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *db2) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case schemas.TinyInt:
res = schemas.SmallInt
return res
case schemas.UnsignedBigInt:
return schemas.BigInt
case schemas.UnsignedInt, schemas.BigInt:
return schemas.BigInt
case schemas.Int, schemas.Integer:
return schemas.Integer
case schemas.Bit, schemas.Bool, schemas.Boolean:
res = schemas.Boolean
return res
case schemas.Binary:
res = schemas.Binary
case schemas.DateTime:
res = schemas.TimeStamp
case schemas.TimeStampz:
return "timestamp with time zone"
case schemas.TinyText, schemas.MediumText, schemas.LongText:
res = schemas.Text
case schemas.NVarchar:
res = schemas.Varchar
case schemas.Uuid:
return schemas.Uuid
case schemas.VarBinary, schemas.Bytea:
res = schemas.VarBinary
if c.Length == 0 {
return res + "(MAX)"
}
case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
return schemas.Blob
default:
res = t
}
hasLen1 := (c.Length > 0)
hasLen2 := (c.Length2 > 0)
if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")"
}
return res
}
func (db *db2) SupportInsertMany() bool {
return true
}
func (db *db2) IsReserved(name string) bool {
_, ok := db2ReservedWords[name]
return ok
}
func (db *db2) AutoIncrStr() string {
return ""
}
func (db *db2) SupportEngine() bool {
return false
}
func (db *db2) SupportCharset() bool {
return false
}
func (db *db2) IndexOnTable() bool {
return false
}
func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
if tableName == "" {
tableName = table.Name
}
quoter := db.Quoter()
var b strings.Builder
b.WriteString("CREATE TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if !col.DefaultIsEmpty {
col.Nullable = false
}
s, _ := ColumnString(db, col, false)
b.WriteString(s)
if col.IsAutoIncrement {
b.WriteString(" GENERATED BY DEFAULT AS IDENTITY (START WITH 1, INCREMENT BY 1)")
}
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}
}
if len(table.PrimaryKeys) > 0 {
b.WriteString(", PRIMARY KEY (")
b.WriteString(quoter.Join(table.PrimaryKeys, ","))
b.WriteString(")")
}
b.WriteString(")")
return []string{b.String()}, true
}
func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
if len(db.uri.Schema) == 0 {
args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.uri.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` +
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *db2) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = oracleQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = oracleQuoter
}
}
func (db *db2) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
if len(db.uri.Schema) == 0 {
return db.HasRecords(queryer, ctx, `SELECT tabname FROM syscat.tables WHERE tabname = ?`, tableName)
}
return db.HasRecords(queryer, ctx, `SELECT tabname FROM syscat.tables WHERE tabschema = ? AND tabname = ?`,
db.uri.Schema, tableName,
)
}
func (db *db2) ModifyColumnSQL(tableName string, col *schemas.Column) string {
if len(db.uri.Schema) == 0 {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SQLType(col))
}
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.uri.Schema, tableName, col.Name, db.SQLType(col))
}
// DropTableSQL returns drop table SQL
func (db *db2) DropTableSQL(tableName string) (string, bool) {
quote := db.Quoter().Quote
return fmt.Sprintf("DROP TABLE %s", quote(tableName)), false
}
func (db *db2) DropIndexSQL(tableName string, index *schemas.Index) string {
quote := db.Quoter().Quote
idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == schemas.UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
}
if db.uri.Schema != "" {
idxName = db.uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}
func (db *db2) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{db.uri.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3"
if len(db.uri.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
rows, err := queryer.QueryContext(ctx, query, args...)
if err != nil {
return false, err
}
defer rows.Close()
return rows.Next(), nil
}
func (db *db2) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := `Select c.colname as column_name,
c.colno as position,
c.typename as data_type,
c.length,
c.scale,
c.remarks as description,
case when c.nulls = 'Y' then 1 else 0 end as nullable,
default as default_value,
case when c.identity ='Y' then 1 else 0 end as is_identity,
case when c.generated ='' then 0 else 1 end as is_computed,
c.text as computed_formula
from syscat.columns c
inner join syscat.tables t on
t.tabschema = c.tabschema and t.tabname = c.tabname
where t.type = 'T' AND c.tabname = ?`
var f string
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
f = " AND c.tabschema = ?"
}
s = s + f
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
return nil, nil, err
}
defer rows.Close()
cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var colName, position, dataType, numericScale string
var description, colDefault, computedFormula, maxLenStr *string
var isComputed bool
err = rows.Scan(&colName, &position, &dataType, &maxLenStr, &numericScale, &description, &col.Nullable, &colDefault, &col.IsPrimaryKey, &isComputed, &computedFormula)
if err != nil {
return nil, nil, err
}
//fmt.Println(colName, position, dataType, maxLenStr, numericScale, description, col.Nullable, colDefault, col.IsPrimaryKey, isComputed, computedFormula)
var maxLen int
if maxLenStr != nil {
maxLen, err = strconv.Atoi(*maxLenStr)
if err != nil {
return nil, nil, err
}
}
col.Name = strings.Trim(colName, `" `)
if colDefault != nil {
col.DefaultIsEmpty = false
col.Default = *colDefault
}
if colDefault != nil && strings.HasPrefix(*colDefault, "nextval(") {
col.IsAutoIncrement = true
}
switch dataType {
case "character", "CHARACTER":
col.SQLType = schemas.SQLType{Name: schemas.Char, DefaultLength: 0, DefaultLength2: 0}
case "timestamp without time zone":
col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 0, DefaultLength2: 0}
case "timestamp with time zone":
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "double precision":
col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: 0, DefaultLength2: 0}
case "boolean":
col.SQLType = schemas.SQLType{Name: schemas.Bool, DefaultLength: 0, DefaultLength2: 0}
case "time without time zone":
col.SQLType = schemas.SQLType{Name: schemas.Time, DefaultLength: 0, DefaultLength2: 0}
case "oid":
col.SQLType = schemas.SQLType{Name: schemas.BigInt, DefaultLength: 0, DefaultLength2: 0}
default:
col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0}
}
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType: %v", dataType)
}
col.Length = maxLen
if col.SQLType.IsText() || col.SQLType.IsTime() {
if col.Default != "" {
col.Default = "'" + col.Default + "'"
} else {
if col.DefaultIsEmpty {
col.Default = "''"
}
}
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
return colSeq, cols, nil
}
func (db *db2) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'"
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
s = s + " AND TABSCHEMA = ?"
}
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
return nil, err
}
defer rows.Close()
tables := make([]*schemas.Table, 0)
for rows.Next() {
table := schemas.NewEmptyTable()
var name string
err = rows.Scan(&name)
if err != nil {
return nil, err
}
table.Name = name
tables = append(tables, table)
}
return tables, nil
}
func (db *db2) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf(`select uniquerule,
indname as index_name,
replace(substring(colnames,2,length(colnames)),'+',',') as columns
from syscat.indexes WHERE tabname = ?`)
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
s = s + " AND tabschema=?"
}
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
return nil, err
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
for rows.Next() {
var indexTypeName, indexName, columns string
/*when 'P' then 'Primary key'
when 'U' then 'Unique'
when 'D' then 'Nonunique'*/
err = rows.Scan(&indexTypeName, &indexName, &columns)
if err != nil {
return nil, err
}
indexName = strings.Trim(indexName, `" `)
if strings.HasSuffix(indexName, "_pkey") {
continue
}
var indexType int
if strings.EqualFold(indexTypeName, "U") {
indexType = schemas.UniqueType
} else if strings.EqualFold(indexTypeName, "D") {
indexType = schemas.IndexType
}
var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName):]
isRegular = true
if newIdxName != "" {
indexName = newIdxName
}
}
index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
colNames := strings.Split(columns, ",")
for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
}
index.IsRegular = isRegular
indexes[index.Name] = index
}
return indexes, nil
}
func (db *db2) Filters() []Filter {
return []Filter{}
}
type db2Driver struct {
baseDriver
}
func (p *db2Driver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (g *db2Driver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB":
var s sql.NullString
return &s, nil
case "NUMBER":
var s sql.NullString
return &s, nil
case "DATE":
var s sql.NullTime
return &s, nil
case "BLOB":
var r sql.RawBytes
return &r, nil
default:
var r sql.RawBytes
return &r, nil
}
}
func (p *db2Driver) Parse(driverName, dataSourceName string) (*URI, error) {
var dbName string
var defaultSchema string
kv := strings.Split(dataSourceName, ";")
for _, c := range kv {
vv := strings.SplitN(strings.TrimSpace(c), "=", 2)
if len(vv) == 2 {
switch strings.ToLower(vv[0]) {
case "database":
dbName = vv[1]
case "uid":
defaultSchema = vv[1]
}
}
}
if dbName == "" {
return nil, errors.New("no db name provided")
}
return &URI{
DBName: dbName,
DBType: "db2",
Schema: defaultSchema,
}, nil
}

View File

@ -32,17 +32,21 @@ type URI struct {
// SetSchema set schema
func (uri *URI) SetSchema(schema string) {
// hack me
if uri.DBType == schemas.POSTGRES {
uri.Schema = strings.TrimSpace(schema)
}
}
type DialectFeatures struct {
DefaultClause string // default key word
}
// Dialect represents a kind of database
type Dialect interface {
Init(*URI) error
URI() *URI
Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error)
Features() *DialectFeatures
SQLType(*schemas.Column) string
Alias(string) string // return what a sql type's alias of
@ -103,6 +107,12 @@ func (db *Base) URI() *URI {
return db.uri
}
func (db *Base) Features() *DialectFeatures {
return &DialectFeatures{
DefaultClause: "DEFAULT",
}
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote
@ -211,16 +221,17 @@ func regDrvsNDialects() bool {
getDriver func() Driver
getDialect func() Dialect
}{
"mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
"odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
"mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
"mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
"postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
"pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
"godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }},
"mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
"odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
"mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
"mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
"postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
"pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
"godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }},
"go_ibm_db": {"db2", func() Driver { return &db2Driver{} }, func() Dialect { return &db2{} }},
}
for driverName, v := range providedDrvsNDialects {
@ -252,43 +263,46 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString("PRIMARY KEY "); err != nil {
if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err
}
if col.IsAutoIncrement {
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if err := bd.WriteByte(' '); err != nil {
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
return "", err
}
}
}
if col.Default != "" {
if _, err := bd.WriteString("DEFAULT "); err != nil {
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Default); err != nil {
if _, err := bd.WriteString(dialect.Features().DefaultClause); err != nil {
return "", err
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Default); err != nil {
return "", err
}
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if col.Nullable {
if _, err := bd.WriteString("NULL "); err != nil {
if _, err := bd.WriteString("NULL"); err != nil {
return "", err
}
} else {
if _, err := bd.WriteString("NOT NULL "); err != nil {
if _, err := bd.WriteString("NOT NULL"); err != nil {
return "", err
}
}

View File

@ -966,38 +966,35 @@ func (db *postgres) AutoIncrStr() string {
}
func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
quoter := db.Quoter()
sql += quoter.Quote(tableName)
sql += " ("
var b strings.Builder
b.WriteString("CREATE TABLE IF NOT EXIST ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1)
b.WriteString(s)
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1)
sql += s
sql = strings.TrimSpace(sql)
sql += ", "
if len(table.PrimaryKeys) > 1 {
b.WriteString("PRIMARY KEY ( ")
b.WriteString(quoter.Join(table.PrimaryKeys, ","))
b.WriteString(" )")
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}
sql = sql[:len(sql)-2]
}
sql += ")"
return []string{sql}, true
b.WriteString(")")
return []string{b.String()}, false
}
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

1
go.mod
View File

@ -7,6 +7,7 @@ require (
github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.7.4
github.com/jackc/pgx/v4 v4.12.0
github.com/ibmdb/go_ibm_db v0.4.1
github.com/json-iterator/go v1.1.11
github.com/lib/pq v1.10.2
github.com/mattn/go-sqlite3 v1.14.8

6
go.sum
View File

@ -203,6 +203,12 @@ github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/ibmdb/go_ibm_db v0.1.0 h1:Ok7W7wysBUa8eyVYxWLS5vIA0VomTsurK57l5Rah1M8=
github.com/ibmdb/go_ibm_db v0.1.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/ibmdb/go_ibm_db v0.3.0 h1:KCSVFS9eXmlTEFL8ScyROsYWmP02G3eGce7VRAt4Csk=
github.com/ibmdb/go_ibm_db v0.3.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/ibmdb/go_ibm_db v0.4.1 h1:IYZqoKTzD9xtkzLIkp8u6zzg7/4v7nFOfHzF79agvak=
github.com/ibmdb/go_ibm_db v0.4.1/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMWAQ=
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=

View File

@ -62,7 +62,7 @@ func TestCacheFind(t *testing.T) {
}
boxes = make([]MailBox, 0, 2)
assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes))
assert.NoError(t, testEngine.Alias("a").Where("`a`.`id` > -1").Asc("a.id").Find(&boxes))
assert.EqualValues(t, 2, len(boxes))
for i, box := range boxes {
assert.Equal(t, inserts[i].Id, box.Id)
@ -77,7 +77,7 @@ func TestCacheFind(t *testing.T) {
}
boxes2 := make([]MailBox4, 0, 2)
assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2))
assert.NoError(t, testEngine.Table("mail_box").Where("`mail_box`.`id` > -1").Asc("mail_box.id").Find(&boxes2))
assert.EqualValues(t, 2, len(boxes2))
for i, box := range boxes2 {
assert.Equal(t, inserts[i].Id, box.Id)
@ -164,14 +164,14 @@ func TestCacheGet(t *testing.T) {
assert.NoError(t, err)
var box1 MailBox3
has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1)
has, err := testEngine.Where("`id` = ?", inserts[0].Id).Get(&box1)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "user1", box1.Username)
assert.EqualValues(t, "pass1", box1.Password)
var box2 MailBox3
has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2)
has, err = testEngine.Where("`id` = ?", inserts[0].Id).Get(&box2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "user1", box2.Username)

View File

@ -0,0 +1,11 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build db2
package integrations
import (
_ "github.com/ibmdb/go_ibm_db"
)

View File

@ -126,7 +126,7 @@ func TestDump(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, sess.Commit())
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} {
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL, schemas.DB2} {
name := fmt.Sprintf("dump_%v.sql", tp)
t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.DumpAllToFile(name, tp))
@ -169,7 +169,7 @@ func TestDumpTables(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, sess.Commit())
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} {
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL, schemas.DB2} {
name := fmt.Sprintf("dump_%v-table.sql", tp)
t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp))

View File

@ -52,11 +52,11 @@ func TestSetExpr(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
tableName := testEngine.TableName(new(UserExprIssue), true)
tableName := testEngine.Quote(testEngine.TableName(new(UserExprIssue), true))
cnt, err = testEngine.SetExpr("issue_id",
builder.Select("id").
builder.Select("`id`").
From(tableName).
Where(builder.Eq{"id": issue.Id})).
Where(builder.Eq{"`id`": issue.Id})).
ID(1).
Update(new(UserExpr))
assert.NoError(t, err)

View File

@ -37,49 +37,49 @@ func TestBuilder(t *testing.T) {
assert.NoError(t, err)
var cond Condition
has, err := testEngine.Where(builder.Eq{"col_name": "col1"}).Get(&cond)
has, err := testEngine.Where(builder.Eq{"`col_name`": "col1"}).Get(&cond)
assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1"}.
And(builder.Eq{"op": OpEqual})).
has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Eq{"`op`": OpEqual})).
NoAutoCondition().
Get(&cond)
assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1", "op": OpEqual, "value": "1"}).
has, err = testEngine.Where(builder.Eq{"`col_name`": "col1", "`op`": OpEqual, "`value`": "1"}).
NoAutoCondition().
Get(&cond)
assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1"}.
And(builder.Neq{"op": OpEqual})).
has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Neq{"`op`": OpEqual})).
NoAutoCondition().
Get(&cond)
assert.NoError(t, err)
assert.Equal(t, false, has, "records should not exist")
var conds []Condition
err = testEngine.Where(builder.Eq{"col_name": "col1"}.
And(builder.Eq{"op": OpEqual})).
err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Eq{"`op`": OpEqual})).
Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.Where(builder.Like{"col_name", "col"}).Find(&conds)
err = testEngine.Where(builder.Like{"`col_name`", "col"}).Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.Where(builder.Expr("col_name = ?", "col1")).Find(&conds)
err = testEngine.Where(builder.Expr("`col_name` = ?", "col1")).Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.Where(builder.In("col_name", "col1", "col2")).Find(&conds)
err = testEngine.Where(builder.In("`col_name`", "col1", "col2")).Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
@ -91,8 +91,8 @@ func TestBuilder(t *testing.T) {
// complex condtions
var where = builder.NewCond()
if true {
where = where.And(builder.Eq{"col_name": "col1"})
where = where.Or(builder.And(builder.In("col_name", "col1", "col2"), builder.Expr("col_name = ?", "col1")))
where = where.And(builder.Eq{"`col_name`": "col1"})
where = where.Or(builder.And(builder.In("`col_name`", "col1", "col2"), builder.Expr("`col_name` = ?", "col1")))
}
conds = make([]Condition, 0)
@ -215,7 +215,7 @@ func TestFindAndCount(t *testing.T) {
assert.NoError(t, err)
var results []FindAndCount
sess := testEngine.Where("name = ?", "test1")
sess := testEngine.Where("`name` = ?", "test1")
conds := sess.Conds()
err = sess.Find(&results)
assert.NoError(t, err)

View File

@ -63,7 +63,7 @@ func TestSQLCount(t *testing.T) {
assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)).
total, err := testEngine.SQL("SELECT count(`id`) FROM " + testEngine.Quote(testEngine.TableName("userinfo_count2", true))).
Count()
assert.NoError(t, err)
assert.EqualValues(t, 0, total)
@ -89,7 +89,7 @@ func TestCountWithOthers(t *testing.T) {
})
assert.NoError(t, err)
total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers))
total, err := testEngine.OrderBy("`id` desc").Limit(1).Count(new(CountWithOthers))
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
}
@ -118,11 +118,11 @@ func TestWithTableName(t *testing.T) {
})
assert.NoError(t, err)
total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName))
total, err := testEngine.OrderBy("`id` desc").Count(new(CountWithTableName))
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{})
total, err = testEngine.OrderBy("`id` desc").Count(CountWithTableName{})
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
}
@ -146,7 +146,7 @@ func TestCountWithSelectCols(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
total, err = testEngine.Select("count(id)").Count(CountWithTableName{})
total, err = testEngine.Select("count(`id`)").Count(CountWithTableName{})
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
}
@ -166,7 +166,7 @@ func TestCountWithGroupBy(t *testing.T) {
})
assert.NoError(t, err)
cnt, err := testEngine.GroupBy("name").Count(new(CountWithTableName))
cnt, err := testEngine.GroupBy("`name`").Count(new(CountWithTableName))
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
}

View File

@ -48,19 +48,19 @@ func TestExistStruct(t *testing.T) {
assert.NoError(t, err)
assert.False(t, has)
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
has, err = testEngine.Where("`name` = ?", "test1").Exist(&RecordExist{})
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.Where("name = ?", "test2").Exist(&RecordExist{})
has, err = testEngine.Where("`name` = ?", "test2").Exist(&RecordExist{})
assert.NoError(t, err)
assert.False(t, has)
has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist()
has, err = testEngine.SQL("select * from "+testEngine.Quote(testEngine.TableName("record_exist", true))+" where `name` = ?", "test1").Exist()
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist()
has, err = testEngine.SQL("select * from "+testEngine.Quote(testEngine.TableName("record_exist", true))+" where `name` = ?", "test2").Exist()
assert.NoError(t, err)
assert.False(t, has)
@ -68,11 +68,11 @@ func TestExistStruct(t *testing.T) {
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
has, err = testEngine.Table("record_exist").Where("`name` = ?", "test1").Exist()
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist()
has, err = testEngine.Table("record_exist").Where("`name` = ?", "test2").Exist()
assert.NoError(t, err)
assert.False(t, has)
@ -124,43 +124,43 @@ func TestExistStructForJoin(t *testing.T) {
defer session.Close()
session.Table("number").
Join("INNER", "order_list", "order_list.id = number.lid").
Join("LEFT", "player", "player.id = order_list.eid").
Where("number.lid = ?", 1)
Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`").
Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`").
Where("`number`.`lid` = ?", 1)
has, err := session.Exist()
assert.NoError(t, err)
assert.True(t, has)
session.Table("number").
Join("INNER", "order_list", "order_list.id = number.lid").
Join("LEFT", "player", "player.id = order_list.eid").
Where("number.lid = ?", 2)
Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`").
Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`").
Where("`number`.`lid` = ?", 2)
has, err = session.Exist()
assert.NoError(t, err)
assert.False(t, has)
session.Table("number").
Select("order_list.id").
Join("INNER", "order_list", "order_list.id = number.lid").
Join("LEFT", "player", "player.id = order_list.eid").
Where("order_list.id = ?", 1)
Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`").
Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`").
Where("`order_list`.`id` = ?", 1)
has, err = session.Exist()
assert.NoError(t, err)
assert.True(t, has)
session.Table("number").
Select("player.id").
Join("INNER", "order_list", "order_list.id = number.lid").
Join("LEFT", "player", "player.id = order_list.eid").
Where("player.id = ?", 2)
Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`").
Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`").
Where("`player`.`id` = ?", 2)
has, err = session.Exist()
assert.NoError(t, err)
assert.False(t, has)
session.Table("number").
Select("player.id").
Join("INNER", "order_list", "order_list.id = number.lid").
Join("LEFT", "player", "player.id = order_list.eid")
Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`").
Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`")
has, err = session.Exist()
assert.NoError(t, err)
assert.True(t, has)
@ -174,15 +174,15 @@ func TestExistStructForJoin(t *testing.T) {
session.Table("number").
Select("player.id").
Join("INNER", "order_list", "order_list.id = number.lid").
Join("LEFT", "player", "player.id = order_list.eid")
Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`").
Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`")
has, err = session.Exist()
assert.Error(t, err)
assert.False(t, has)
session.Table("number").
Select("player.id").
Join("LEFT", "player", "player.id = number.lid")
Join("LEFT", "player", "`player`.`id` = `number`.`lid`")
has, err = session.Exist()
assert.NoError(t, err)
assert.True(t, has)

View File

@ -56,8 +56,8 @@ func TestJoinLimit(t *testing.T) {
var salaries []Salary
err = testEngine.Table("salary").
Join("INNER", "check_list", "check_list.id = salary.lid").
Join("LEFT", "empsetting", "empsetting.id = check_list.eid").
Join("INNER", "check_list", "`check_list`.`id` = `salary`.`lid`").
Join("LEFT", "empsetting", "`empsetting`.`id` = `check_list`.`eid`").
Limit(10, 0).
Find(&salaries)
assert.NoError(t, err)
@ -69,10 +69,10 @@ func TestWhere(t *testing.T) {
assertSync(t, new(Userinfo))
users := make([]Userinfo, 0)
err := testEngine.Where("id > ?", 2).Find(&users)
err := testEngine.Where("`id` > ?", 2).Find(&users)
assert.NoError(t, err)
err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users)
err = testEngine.Where("`id` > ?", 2).And("`id` < ?", 10).Find(&users)
assert.NoError(t, err)
}

View File

@ -6,9 +6,15 @@ package utils
import (
"fmt"
"strings"
)
// IndexName returns index name
func IndexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}
// SeqName returns sequence name for some table
func SeqName(tableName string) string {
return "SEQ_" + strings.ToUpper(tableName)
}

View File

@ -20,3 +20,12 @@ func SliceEq(left, right []string) bool {
}
return true
}
func IndexSlice(s []string, c string) int {
for i, ss := range s {
if c == ss {
return i
}
}
return -1
}

View File

@ -22,6 +22,7 @@ const (
MYSQL DBType = "mysql"
MSSQL DBType = "mssql"
ORACLE DBType = "oracle"
DB2 DBType = "db2"
)
// SQLType represents SQL types

View File

@ -307,16 +307,59 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
// if there is auto increment column and driver don't support return it
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
var sql = sqlStr
if session.engine.dialect.URI().DBType == schemas.ORACLE {
sql = "select seq_atable.currval from dual"
var sql string
var newArgs []interface{}
var needCommit bool
var id int64
var i = utils.IndexSlice(colNames, table.AutoIncrement)
if i > -1 {
id, err = convert.AsInt64(args[i])
if err != nil {
return 0, err
}
}
if session.engine.dialect.URI().DBType == schemas.DB2 || session.engine.dialect.URI().DBType == schemas.ORACLE {
if id == 0 && session.isAutoCommit { // if it's not in transaction
if err := session.Begin(); err != nil {
return 0, err
}
needCommit = true
}
_, err := session.exec(sqlStr, args...)
if err != nil {
if needCommit {
session.Rollback()
}
return 0, err
}
if session.engine.dialect.URI().DBType == schemas.ORACLE {
sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
} else if session.engine.dialect.URI().DBType == schemas.DB2 {
sql = "select IDENTITY_VAL_LOCAL() as id FROM sysibm.sysdummy1"
}
} else {
sql = sqlStr
newArgs = args
}
rows, err := session.queryRows(sql, args...)
if err != nil {
return 0, err
if id == 0 {
err := session.queryRow(sql, newArgs...).Scan(&id)
if err != nil {
if needCommit {
session.Rollback()
}
return 0, err
}
if needCommit {
if err := session.Commit(); err != nil {
return 0, err
}
}
if id == 0 {
return 0, errors.New("insert successfully but not returned id")
}
}
defer rows.Close()
defer handleAfterInsertProcessorFunc(bean)
@ -331,16 +374,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
}
}
var id int64
if !rows.Next() {
if rows.Err() != nil {
return 0, rows.Err()
}
return 0, errors.New("insert successfully but not returned id")
}
if err := rows.Scan(&id); err != nil {
return 1, err
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)

View File

@ -235,6 +235,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil {
fmt.Println("------", tables, err)
return err
}
@ -244,6 +245,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
session.resetStatement()
}()
fmt.Println("-----", tables, len(tables), len(beans))
for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
@ -260,6 +263,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
var oriTable *schemas.Table
for _, tb := range tables {
fmt.Println("----", tb.Name, engine.tbNameWithSchema(tb.Name), "===", tbName, engine.tbNameWithSchema(tbName))
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
break