Add oracle tests #1463

Open
lunny wants to merge 24 commits from lunny/test_oracle2 into master
  1. 39
      .drone.yml
  2. 50
      Makefile
  3. 1
      core/db_test.go
  4. 1
      core/interface.go
  5. 1
      dialects/dialect.go
  6. 184
      dialects/oracle.go
  7. 6
      dialects/time.go
  8. 2
      go.mod
  9. 8
      go.sum
  10. 7
      integrations/engine_test.go
  11. 11
      integrations/oracle_test.go
  12. 8
      integrations/processors_test.go
  13. 2
      integrations/session_cond_test.go
  14. 60
      internal/statements/insert.go
  15. 4
      internal/statements/statement.go
  16. 24
      session_insert.go
  17. 23
      session_schema.go
  18. 6
      session_update.go

39
.drone.yml

@ -398,6 +398,44 @@ services:
# commands:
# - /bin/bash /startDm.sh
---
kind: pipeline
name: test-oracle
depends_on:
- test-cockroach
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-oracle
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_ORACLE_HOST: oracle:1521
TEST_ORACLE_DBNAME: xe
TEST_ORACLE_USERNAME: system
TEST_ORACLE_PASSWORD: oracle
TEST_CACHE_ENABLE: false
commands:
- make test-oracle
- TEST_ORACLE_SCHEMA=xorm make test-oracle
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: oracle
image: oracleinanutshell/oracle-xe-11g:latest
environment:
ORACLE_ALLOW_REMOTE: true
ORACLE_ENABLE_XDB: true
---
kind: pipeline
name: merge_coverage
@ -410,6 +448,7 @@ depends_on:
- test-tidb
- test-cockroach
#- test-dameng
- test-oracle
trigger:
ref:
- refs/heads/master

50
Makefile

@ -32,6 +32,12 @@ TEST_MYSQL_DBNAME ?= xorm_test
TEST_MYSQL_USERNAME ?= root
TEST_MYSQL_PASSWORD ?=
TEST_ORACLE_HOST ?= oracle:1521
TEST_ORACLE_SCHEMA ?=
TEST_ORACLE_DBNAME ?= xe
TEST_ORACLE_USERNAME ?= system
TEST_ORACLE_PASSWORD ?= oracle
TEST_PGSQL_HOST ?= pgsql:5432
TEST_PGSQL_SCHEMA ?=
TEST_PGSQL_DBNAME ?= xorm_test
@ -134,6 +140,23 @@ misspell-check:
fi
misspell -error -i unknwon,destory $(GOFILES)
.PHONY: install-instant-client
install-instant-client:
ifeq ("$(PKG_CONFIG_PATH)", )
wget https://download.oracle.com/otn_software/linux/instantclient/19600/instantclient-basic-linux.x64-19.6.0.0.0dbru.zip
unzip instantclient-basic-linux.x64-19.6.0.0.0dbru.zip -d /usr/local/instantclient
echo "prefixdir=/usr/local/instantclient
libdir=${prefixdir}
includedir=${prefixdir}/sdk/include
Name: OCI
Description: Oracle database driver
Version: 11.2
Libs: -L${libdir} -lclntsh
Cflags: -I${includedir}" > /usr/local/instantclient/oci8.pc
export PKG_CONFIG_PATH=/usr/local/instantclient/oci8.pc
endif
.PHONY: test
test: go-check
$(GO) test $(PACKAGES)
@ -190,6 +213,33 @@ test-mysql\#%: go-check
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-oracle
test-oracle: test-godror
.PNONY: test-oci8
test-oci8: go-check install-instant-client
$(GO) test $(INTEGRATION_PACKAGES) -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-oci8\#%
test-oci8\#%: go-check install-instant-client
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-godror
test-godror: go-check install-instant-client
$(GO) test $(INTEGRATION_PACKAGES) -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-godror\#%
test-godror\#%: go-check install-instant-client
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-postgres
test-postgres: go-check
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \

1
core/db_test.go

@ -96,7 +96,6 @@ func BenchmarkOriQuery(b *testing.B) {
if err != nil {
b.Error(err)
}
//fmt.Println(Id, Name, Title, Age, Alias, NickName)
}
rows.Close()
}

1
core/interface.go

@ -7,6 +7,7 @@ import (
// Queryer represents an interface to query a SQL to get data from database
type Queryer interface {
QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row
QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
}

1
dialects/dialect.go

@ -43,6 +43,7 @@ const (
SequenceAutoincrMode
)
// DialectFeatures represents the features that the dialect supports
type DialectFeatures struct {
AutoincrMode int // 0 autoincrement column, 1 sequence
}

184
dialects/oracle.go

@ -9,11 +9,13 @@ import (
"database/sql"
"errors"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"xorm.io/xorm/core"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
@ -548,14 +550,18 @@ func (db *oracle) Features() *DialectFeatures {
func (db *oracle) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial:
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt,
schemas.UnsignedBigInt, schemas.UnsignedBit, schemas.UnsignedInt,
schemas.Bool,
schemas.Serial, schemas.BigSerial:
res = "NUMBER"
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.TimeStamp
case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.Date
return res
case schemas.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE"
res = "TIMESTAMP"
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
res = "NUMBER"
case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
@ -592,6 +598,14 @@ func (db *oracle) ColumnTypeKind(t string) int {
}
}
func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
var cnt int
if err := queryer.QueryRowContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = :1", seqName).Scan(&cnt); err != nil {
return false, err
}
return cnt > 0, nil
}
func (db *oracle) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
@ -602,40 +616,42 @@ func (db *oracle) IsReserved(name string) bool {
}
func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false
return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false
}
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
quoter := db.Quoter()
sql += quoter.Quote(tableName) + " ("
var b strings.Builder
b.WriteString("CREATE TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
s, _ := ColumnString(db, col, false)
sql += s
// }
sql = strings.TrimSpace(sql)
sql += ", "
b.WriteString(s)
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}
}
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
if len(table.ColumnsSeq()) > 0 {
b.WriteString(", ")
}
b.WriteString("PRIMARY KEY (")
quoter.JoinWrite(&b, pkList, ",")
b.WriteString(")")
}
b.WriteString(")")
sql = sql[:len(sql)-2] + ")"
return sql, false, nil
return b.String(), false, nil
}
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -673,11 +689,30 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
}
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
//s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
// "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
s := `select column_name from user_cons_columns
where constraint_name = (select constraint_name from user_constraints
where table_name = :1 and constraint_type ='P')`
var pkName string
err := queryer.QueryRowContext(ctx, s, tableName).Scan(&pkName)
if err != nil {
if err == sql.ErrNoRows {
err = nil
}
return nil, nil, err
}
rows, err := queryer.QueryContext(ctx, s, args...)
s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH,
USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE,
user_col_comments.comments
FROM USER_TAB_COLS
LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME
AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME
WHERE USER_TAB_COLS.table_name = :1`
rows, err := queryer.QueryContext(ctx, s, tableName)
if err != nil {
return nil, nil, err
}
@ -689,11 +724,11 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
var colName, colDefault, nullable, dataType, dataPrecision, dataScale, comment *string
var dataLen int
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
&dataScale, &nullable)
&dataScale, &nullable, &comment)
if err != nil {
return nil, nil, err
}
@ -710,10 +745,26 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Nullable = false
}
var ignore bool
if comment != nil {
col.Comment = *comment
}
if pkName != "" && pkName == col.Name {
col.IsPrimaryKey = true
var dt string
var len1, len2 int
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", utils.SeqName(tableName))
if err != nil {
return nil, nil, err
}
if has {
col.IsAutoIncrement = true
}
}
var (
ignore bool
dt string
len1, len2 int
)
dts := strings.Split(*dataType, "(")
dt = dts[0]
if len(dts) > 1 {
@ -735,7 +786,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "NUMBER":
col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: len1, DefaultLength2: len2}
case "LONG", "LONG RAW":
case "LONG", "LONG RAW", "NCLOB", "CLOB":
col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0}
case "RAW":
col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0}
@ -752,7 +803,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
}
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType %v %v", *dataType, col.SQLType)
return nil, nil, fmt.Errorf("unknown colType %v %v", *dataType, col.SQLType)
}
col.Length = dataLen
@ -773,8 +824,8 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
}
func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT table_name FROM user_tables"
s := "SELECT table_name FROM user_tables WHERE TABLESPACE_NAME = :1 AND table_name NOT LIKE :2"
args := []interface{}{strings.ToUpper(db.uri.User), "%$%"}
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
@ -856,6 +907,7 @@ func (db *oracle) Filters() []Filter {
}
}
// https://github.com/godror/godror
type godrorDriver struct {
baseDriver
}
@ -866,22 +918,49 @@ func (g *godrorDriver) Features() *DriverFeatures {
}
}
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
func parseNoProtocol(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@
`(?P<net>.*)` + // ip:port
`\/(?P<dbname>.*)`) // dbname
matches := dsnPattern.FindStringSubmatch(dataSourceName)
// tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
for i, match := range matches {
if names[i] == "dbname" {
switch names[i] {
case "dbname":
db.DBName = match
}
}
if db.DBName == "" && len(matches) != 0 {
return nil, errors.New("dbname is empty")
}
return db, nil
}
func parseOracle(driverName, dataSourceName string) (*URI, error) {
var connStr = dataSourceName
if !strings.HasPrefix(connStr, "oracle://") {
return parseNoProtocol(driverName, dataSourceName)
}
u, err := url.Parse(connStr)
if err != nil {
return nil, err
}
db := &URI{
DBType: schemas.ORACLE,
Host: u.Hostname(),
Port: u.Port(),
DBName: strings.TrimLeft(u.RequestURI(), "/"),
}
if u.User != nil {
db.User = u.User.Username()
db.Passwd, _ = u.User.Password()
}
if db.DBName == "" {
return nil, errors.New("dbname is empty")
}
@ -908,27 +987,12 @@ func (g *godrorDriver) GenScanResult(colType string) (interface{}, error) {
}
}
type oci8Driver struct {
godrorDriver
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return parseOracle(driverName, dataSourceName)
}
// dataSourceName=user/password@ipv4:port/dbname
// dataSourceName=user/password@[ipv6]:port/dbname
func (o *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@
`(?P<net>.*)` + // ip:port
`\/(?P<dbname>.*)`) // dbname
matches := dsnPattern.FindStringSubmatch(dataSourceName)
names := dsnPattern.SubexpNames()
for i, match := range matches {
if names[i] == "dbname" {
db.DBName = match
}
}
if db.DBName == "" && len(matches) != 0 {
return nil, errors.New("dbname is empty")
}
return db, nil
type oci8Driver struct {
godrorDriver
}

6
dialects/time.go

@ -32,6 +32,9 @@ func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.C
switch col.SQLType.Name {
case schemas.Date:
if dialect.URI().DBType == schemas.ORACLE {
return t, nil
}
return t.Format("2006-01-02"), nil
case schemas.Time:
var layout = "15:04:05"
@ -40,6 +43,9 @@ func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.C
}
return t.Format(layout), nil
case schemas.DateTime, schemas.TimeStamp:
if dialect.URI().DBType == schemas.ORACLE {
return t, nil
}
var layout = "2006-01-02 15:04:05"
if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length)

2
go.mod

@ -8,8 +8,10 @@ require (
github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.7.4
github.com/jackc/pgx/v4 v4.12.0
github.com/godror/godror v0.25.3
github.com/json-iterator/go v1.1.11
github.com/lib/pq v1.10.2
github.com/mattn/go-oci8 v0.1.0
github.com/mattn/go-sqlite3 v1.14.8
github.com/shopspring/decimal v1.2.0
github.com/stretchr/testify v1.7.0

8
go.sum

@ -75,6 +75,8 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
@ -87,6 +89,8 @@ github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFG
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/godror/godror v0.25.3 h1:ltL94Ct9otjMfUNTRMqyZh0GpepPd9f9pyFgtUciT9k=
github.com/godror/godror v0.25.3/go.mod h1:JgtdZ1iSaNoioa/B53BVVWji9J9iGPDDj2763T5d1So=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@ -105,6 +109,7 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -245,6 +250,8 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-oci8 v0.1.0 h1:HnmdOy+/cLPN43obUokKka4hRE4b7Hp3U3E0fs1clp8=
github.com/mattn/go-oci8 v0.1.0/go.mod h1:wjDx6Xm9q7dFtHJvIlrI99JytznLw5wQ4R+9mNXJwGI=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
@ -435,6 +442,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

7
integrations/engine_test.go

@ -17,12 +17,14 @@ import (
_ "gitee.com/travelliu/dm"
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/godror/godror"
_ "github.com/jackc/pgx/v4/stdlib"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
_ "github.com/ziutek/mymysql/godrv"
_ "modernc.org/sqlite"
"github.com/stretchr/testify/assert"
)
func TestPing(t *testing.T) {
@ -61,8 +63,7 @@ func TestAutoTransaction(t *testing.T) {
engine.Transaction(func(session *xorm.Session) (interface{}, error) {
_, err := session.Insert(TestTx{Msg: "hi"})
assert.NoError(t, err)
return nil, nil
return nil, err
})
has, err := engine.Exist(&TestTx{Msg: "hi"})

11
integrations/oracle_test.go

@ -0,0 +1,11 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build oralce
package integrations
import (
_ "github.com/mattn/go-oci8"
)

8
integrations/processors_test.go

@ -10,6 +10,7 @@ import (
"testing"
"xorm.io/xorm"
"xorm.io/xorm/dialects"
"github.com/stretchr/testify/assert"
)
@ -885,7 +886,8 @@ func TestAfterLoadProcessor(t *testing.T) {
}
type AfterInsertStruct struct {
Id int64
Id int64
Dialect dialects.Dialect `xorm:"-"`
}
func (a *AfterInsertStruct) AfterInsert() {
@ -899,6 +901,8 @@ func TestAfterInsert(t *testing.T) {
assertSync(t, new(AfterInsertStruct))
_, err := testEngine.Insert(&AfterInsertStruct{})
_, err := testEngine.Insert(&AfterInsertStruct{
Dialect: testEngine.Dialect(),
})
assert.NoError(t, err)
}

2
integrations/session_cond_test.go

@ -85,7 +85,7 @@ func TestBuilder(t *testing.T) {
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.NotIn("col_name", "col1", "col2").Find(&conds)
err = testEngine.NotIn("`col_name`", "col1", "col2").Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 0, len(conds), "records should not exist")

60
internal/statements/insert.go

@ -27,7 +27,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem
}
// GenInsertSQL generates insert beans SQL
func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) {
func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (*builder.BytesWriter, error) {
var (
buf = builder.NewWriter()
exprs = statement.ExprColumns
@ -36,11 +36,11 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
)
if _, err := buf.WriteString("INSERT INTO "); err != nil {
return "", nil, err
return nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
return "", nil, err
return nil, err
}
var hasInsertColumns = len(colNames) > 0
@ -58,19 +58,19 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
statement.dialect.URI().DBType != schemas.DAMENG {
if statement.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return "", nil, err
return nil, err
}
} else {
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
return nil, err
}
if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil {
return "", nil, err
return nil, err
}
}
} else {
if _, err := buf.WriteString(" ("); err != nil {
return "", nil, err
return nil, err
}
if needSeq {
@ -78,106 +78,106 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil {
return "", nil, err
return nil, err
}
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
return nil, err
}
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
return nil, err
}
if statement.Conds().IsValid() {
if _, err := buf.WriteString(" SELECT "); err != nil {
return "", nil, err
return nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
return nil, err
}
if needSeq {
if len(args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
return nil, err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
return "", nil, err
return nil, err
}
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
return nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
return nil, err
}
}
if _, err := buf.WriteString(" FROM "); err != nil {
return "", nil, err
return nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
return "", nil, err
return nil, err
}
if _, err := buf.WriteString(" WHERE "); err != nil {
return "", nil, err
return nil, err
}
if err := statement.Conds().WriteTo(buf); err != nil {
return "", nil, err
return nil, err
}
} else {
if _, err := buf.WriteString(" VALUES ("); err != nil {
return "", nil, err
return nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
return nil, err
}
// Insert tablename (id) Values(seq_tablename.nextval)
if needSeq {
if hasInsertColumns {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
return nil, err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
return "", nil, err
return nil, err
}
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
return nil, err
}
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
return nil, err
}
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
return nil, err
}
}
}
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING "); err != nil {
return "", nil, err
return nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return "", nil, err
return nil, err
}
}
return buf.String(), buf.Args(), nil
return buf, nil
}
// GenInsertMapSQL generates insert map SQL

4
internal/statements/statement.go

@ -91,6 +91,10 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ
return statement
}
func (statement *Statement) Dialect() dialects.Dialect {
return statement.dialect
}
// SetTableName set table name
func (statement *Statement) SetTableName(tableName string) {
statement.tableName = tableName

24
session_insert.go

@ -257,7 +257,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
var tableName = session.statement.TableName()
if tableName == "" {
return 0, ErrTableNotFound
}
@ -271,7 +273,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
processor.BeforeInsert()
}
var tableName = session.statement.TableName()
table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean)
@ -279,11 +280,12 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
return 0, err
}
sqlStr, args, err := session.statement.GenInsertSQL(colNames, args)
buf, err := session.statement.GenInsertSQL(colNames, args)
if err != nil {
return 0, err
}
sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)
var sqlStr = session.engine.dialect.Quoter().Replace(buf.String())
args = buf.Args()
handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit {
@ -383,7 +385,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
return 1, convert.AssignValue(*aiValue, id)
}
res, err := session.exec(sqlStr, args...)
res, err := session.exec(buf.String(), buf.Args()...)
if err != nil {
return 0, err
}
@ -503,13 +505,21 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
// if time is non-empty, then set to auto time
val, t, err := session.engine.nowTime(col)
if err != nil {
return nil, nil, err
}
args = append(args, val)
if session.engine.dialect.URI().DBType == schemas.ORACLE {
if col.SQLType.IsNumeric() {
args = append(args, t.Unix())
} else {
args = append(args, t)
}
} else {
args = append(args, val)
}
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {

23
session_schema.go

@ -148,8 +148,27 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName)
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
var tableName string
switch beanOrTableName.(type) {
case *schemas.Table:
table := beanOrTableName.(*schemas.Table)
tableName = table.Name
case string:
tableName = beanOrTableName.(string)
default:
v := utils.ReflectValue(beanOrTableName)
table, err := session.engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
if session.statement.AltTableName != "" {
tableName = session.statement.AltTableName
} else {
tableName = table.Name
}
}
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(tableName)
if !checkIfExist {
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
if err != nil {

6
session_update.go

@ -220,7 +220,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err
}
if session.engine.dialect.URI().DBType == schemas.ORACLE {
args = append(args, t)
if col.SQLType.IsNumeric() {
args = append(args, t.Unix())
} else {
args = append(args, t)
}
} else {
args = append(args, val)
}

Loading…
Cancel
Save