diff --git a/.drone.yml b/.drone.yml index 34b9a514..faaeca17 100644 --- a/.drone.yml +++ b/.drone.yml @@ -363,6 +363,41 @@ services: commands: - /cockroach/cockroach start --insecure +# --- +# kind: pipeline +# name: test-dameng +# depends_on: +# - test-cockroach +# trigger: +# ref: +# - refs/heads/master +# - refs/pull/*/head +# steps: +# - name: test-dameng +# pull: never +# image: golang:1.15 +# volumes: +# - name: cache +# path: /go/pkg/mod +# environment: +# TEST_DAMENG_HOST: "dameng:5236" +# TEST_DAMENG_USERNAME: SYSDBA +# TEST_DAMENG_PASSWORD: SYSDBA +# commands: +# - sleep 30 +# - make test-dameng + +# volumes: +# - name: cache +# host: +# path: /tmp/cache + +# services: +# - name: dameng +# image: lunny/dm:v1.0 +# commands: +# - /bin/bash /startDm.sh + --- kind: pipeline name: merge_coverage @@ -374,6 +409,7 @@ depends_on: - test-mssql - test-tidb - test-cockroach + #- test-dameng trigger: ref: - refs/heads/master diff --git a/Makefile b/Makefile index e986082e..e9bd4129 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,10 @@ TEST_TIDB_DBNAME ?= xorm_test TEST_TIDB_USERNAME ?= root TEST_TIDB_PASSWORD ?= +TEST_DAMENG_HOST ?= dameng:5236 +TEST_DAMENG_USERNAME ?= SYSDBA +TEST_DAMENG_PASSWORD ?= SYSDBA + TEST_CACHE_ENABLE ?= false TEST_QUOTE_POLICY ?= always @@ -240,7 +244,6 @@ test-sqlite\#%: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \ -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic - .PNONY: test-tidb test-tidb: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ @@ -253,6 +256,18 @@ test-tidb\#%: go-check -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic +.PNONY: test-dameng +test-dameng: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=dm -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="dm://$(TEST_DAMENG_USERNAME):$(TEST_DAMENG_PASSWORD)@$(TEST_DAMENG_HOST)" \ + -coverprofile=dameng.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PHONY: test-dameng\#% +test-dameng\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=dm -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="dm://$(TEST_DAMENG_USERNAME):$(TEST_DAMENG_PASSWORD)@$(TEST_DAMENG_HOST)" \ + -coverprofile=dameng.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + .PHONY: vet vet: $(GO) vet $(shell $(GO) list ./...) diff --git a/convert/conversion.go b/convert/conversion.go index 78a9fd78..b69e345c 100644 --- a/convert/conversion.go +++ b/convert/conversion.go @@ -283,11 +283,9 @@ func Assign(dest, src interface{}, originalLocation *time.Location, convertedLoc } } - var sv reflect.Value - switch d := dest.(type) { case *string: - sv = reflect.ValueOf(src) + var sv = reflect.ValueOf(src) switch sv.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, diff --git a/dialects/dameng.go b/dialects/dameng.go new file mode 100644 index 00000000..5ba0cfb5 --- /dev/null +++ b/dialects/dameng.go @@ -0,0 +1,1181 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + + "xorm.io/xorm/convert" + "xorm.io/xorm/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func init() { + RegisterDriver("dm", &damengDriver{}) + RegisterDialect(schemas.DAMENG, func() Dialect { + return &dameng{} + }) +} + +var ( + damengReservedWords = map[string]bool{ + "ACCESS": true, + "ACCOUNT": true, + "ACTIVATE": true, + "ADD": true, + "ADMIN": true, + "ADVISE": true, + "AFTER": true, + "ALL": true, + "ALL_ROWS": true, + "ALLOCATE": true, + "ALTER": true, + "ANALYZE": true, + "AND": true, + "ANY": true, + "ARCHIVE": true, + "ARCHIVELOG": true, + "ARRAY": true, + "AS": true, + "ASC": true, + "AT": true, + "AUDIT": true, + "AUTHENTICATED": true, + "AUTHORIZATION": true, + "AUTOEXTEND": true, + "AUTOMATIC": true, + "BACKUP": true, + "BECOME": true, + "BEFORE": true, + "BEGIN": true, + "BETWEEN": true, + "BFILE": true, + "BITMAP": true, + "BLOB": true, + "BLOCK": true, + "BODY": true, + "BY": true, + "CACHE": true, + "CACHE_INSTANCES": true, + "CANCEL": true, + "CASCADE": true, + "CAST": true, + "CFILE": true, + "CHAINED": true, + "CHANGE": true, + "CHAR": true, + "CHAR_CS": true, + "CHARACTER": true, + "CHECK": true, + "CHECKPOINT": true, + "CHOOSE": true, + "CHUNK": true, + "CLEAR": true, + "CLOB": true, + "CLONE": true, + "CLOSE": true, + "CLOSE_CACHED_OPEN_CURSORS": true, + "CLUSTER": true, + "COALESCE": true, + "COLUMN": true, + "COLUMNS": true, + "COMMENT": true, + "COMMIT": true, + "COMMITTED": true, + "COMPATIBILITY": true, + "COMPILE": true, + "COMPLETE": true, + "COMPOSITE_LIMIT": true, + "COMPRESS": true, + "COMPUTE": true, + "CONNECT": true, + "CONNECT_TIME": true, + "CONSTRAINT": true, + "CONSTRAINTS": true, + "CONTENTS": true, + "CONTINUE": true, + "CONTROLFILE": true, + "CONVERT": true, + "COST": true, + "CPU_PER_CALL": true, + "CPU_PER_SESSION": true, + "CREATE": true, + "CURRENT": true, + "CURRENT_SCHEMA": true, + "CURREN_USER": true, + "CURSOR": true, + "CYCLE": true, + "DANGLING": true, + "DATABASE": true, + "DATAFILE": true, + "DATAFILES": true, + "DATAOBJNO": true, + "DATE": true, + "DBA": true, + "DBHIGH": true, + "DBLOW": true, + "DBMAC": true, + "DEALLOCATE": true, + "DEBUG": true, + "DEC": true, + "DECIMAL": true, + "DECLARE": true, + "DEFAULT": true, + "DEFERRABLE": true, + "DEFERRED": true, + "DEGREE": true, + "DELETE": true, + "DEREF": true, + "DESC": true, + "DIRECTORY": true, + "DISABLE": true, + "DISCONNECT": true, + "DISMOUNT": true, + "DISTINCT": true, + "DISTRIBUTED": true, + "DML": true, + "DOUBLE": true, + "DROP": true, + "DUMP": true, + "EACH": true, + "ELSE": true, + "ENABLE": true, + "END": true, + "ENFORCE": true, + "ENTRY": true, + "ESCAPE": true, + "EXCEPT": true, + "EXCEPTIONS": true, + "EXCHANGE": true, + "EXCLUDING": true, + "EXCLUSIVE": true, + "EXECUTE": true, + "EXISTS": true, + "EXPIRE": true, + "EXPLAIN": true, + "EXTENT": true, + "EXTENTS": true, + "EXTERNALLY": true, + "FAILED_LOGIN_ATTEMPTS": true, + "FALSE": true, + "FAST": true, + "FILE": true, + "FIRST_ROWS": true, + "FLAGGER": true, + "FLOAT": true, + "FLOB": true, + "FLUSH": true, + "FOR": true, + "FORCE": true, + "FOREIGN": true, + "FREELIST": true, + "FREELISTS": true, + "FROM": true, + "FULL": true, + "FUNCTION": true, + "GLOBAL": true, + "GLOBALLY": true, + "GLOBAL_NAME": true, + "GRANT": true, + "GROUP": true, + "GROUPS": true, + "HASH": true, + "HASHKEYS": true, + "HAVING": true, + "HEADER": true, + "HEAP": true, + "IDENTIFIED": true, + "IDGENERATORS": true, + "IDLE_TIME": true, + "IF": true, + "IMMEDIATE": true, + "IN": true, + "INCLUDING": true, + "INCREMENT": true, + "INDEX": true, + "INDEXED": true, + "INDEXES": true, + "INDICATOR": true, + "IND_PARTITION": true, + "INITIAL": true, + "INITIALLY": true, + "INITRANS": true, + "INSERT": true, + "INSTANCE": true, + "INSTANCES": true, + "INSTEAD": true, + "INT": true, + "INTEGER": true, + "INTERMEDIATE": true, + "INTERSECT": true, + "INTO": true, + "IS": true, + "ISOLATION": true, + "ISOLATION_LEVEL": true, + "KEEP": true, + "KEY": true, + "KILL": true, + "LABEL": true, + "LAYER": true, + "LESS": true, + "LEVEL": true, + "LIBRARY": true, + "LIKE": true, + "LIMIT": true, + "LINK": true, + "LIST": true, + "LOB": true, + "LOCAL": true, + "LOCK": true, + "LOCKED": true, + "LOG": true, + "LOGFILE": true, + "LOGGING": true, + "LOGICAL_READS_PER_CALL": true, + "LOGICAL_READS_PER_SESSION": true, + "LONG": true, + "MANAGE": true, + "MASTER": true, + "MAX": true, + "MAXARCHLOGS": true, + "MAXDATAFILES": true, + "MAXEXTENTS": true, + "MAXINSTANCES": true, + "MAXLOGFILES": true, + "MAXLOGHISTORY": true, + "MAXLOGMEMBERS": true, + "MAXSIZE": true, + "MAXTRANS": true, + "MAXVALUE": true, + "MIN": true, + "MEMBER": true, + "MINIMUM": true, + "MINEXTENTS": true, + "MINUS": true, + "MINVALUE": true, + "MLSLABEL": true, + "MLS_LABEL_FORMAT": true, + "MODE": true, + "MODIFY": true, + "MOUNT": true, + "MOVE": true, + "MTS_DISPATCHERS": true, + "MULTISET": true, + "NATIONAL": true, + "NCHAR": true, + "NCHAR_CS": true, + "NCLOB": true, + "NEEDED": true, + "NESTED": true, + "NETWORK": true, + "NEW": true, + "NEXT": true, + "NOARCHIVELOG": true, + "NOAUDIT": true, + "NOCACHE": true, + "NOCOMPRESS": true, + "NOCYCLE": true, + "NOFORCE": true, + "NOLOGGING": true, + "NOMAXVALUE": true, + "NOMINVALUE": true, + "NONE": true, + "NOORDER": true, + "NOOVERRIDE": true, + "NOPARALLEL": true, + "NOREVERSE": true, + "NORMAL": true, + "NOSORT": true, + "NOT": true, + "NOTHING": true, + "NOWAIT": true, + "NULL": true, + "NUMBER": true, + "NUMERIC": true, + "NVARCHAR2": true, + "OBJECT": true, + "OBJNO": true, + "OBJNO_REUSE": true, + "OF": true, + "OFF": true, + "OFFLINE": true, + "OID": true, + "OIDINDEX": true, + "OLD": true, + "ON": true, + "ONLINE": true, + "ONLY": true, + "OPCODE": true, + "OPEN": true, + "OPTIMAL": true, + "OPTIMIZER_GOAL": true, + "OPTION": true, + "OR": true, + "ORDER": true, + "ORGANIZATION": true, + "OSLABEL": true, + "OVERFLOW": true, + "OWN": true, + "PACKAGE": true, + "PARALLEL": true, + "PARTITION": true, + "PASSWORD": true, + "PASSWORD_GRACE_TIME": true, + "PASSWORD_LIFE_TIME": true, + "PASSWORD_LOCK_TIME": true, + "PASSWORD_REUSE_MAX": true, + "PASSWORD_REUSE_TIME": true, + "PASSWORD_VERIFY_FUNCTION": true, + "PCTFREE": true, + "PCTINCREASE": true, + "PCTTHRESHOLD": true, + "PCTUSED": true, + "PCTVERSION": true, + "PERCENT": true, + "PERMANENT": true, + "PLAN": true, + "PLSQL_DEBUG": true, + "POST_TRANSACTION": true, + "PRECISION": true, + "PRESERVE": true, + "PRIMARY": true, + "PRIOR": true, + "PRIVATE": true, + "PRIVATE_SGA": true, + "PRIVILEGE": true, + "PRIVILEGES": true, + "PROCEDURE": true, + "PROFILE": true, + "PUBLIC": true, + "PURGE": true, + "QUEUE": true, + "QUOTA": true, + "RANGE": true, + "RAW": true, + "RBA": true, + "READ": true, + "READUP": true, + "REAL": true, + "REBUILD": true, + "RECOVER": true, + "RECOVERABLE": true, + "RECOVERY": true, + "REF": true, + "REFERENCES": true, + "REFERENCING": true, + "REFRESH": true, + "RENAME": true, + "REPLACE": true, + "RESET": true, + "RESETLOGS": true, + "RESIZE": true, + "RESOURCE": true, + "RESTRICTED": true, + "RETURN": true, + "RETURNING": true, + "REUSE": true, + "REVERSE": true, + "REVOKE": true, + "ROLE": true, + "ROLES": true, + "ROLLBACK": true, + "ROW": true, + "ROWID": true, + "ROWNUM": true, + "ROWS": true, + "RULE": true, + "SAMPLE": true, + "SAVEPOINT": true, + "SB4": true, + "SCAN_INSTANCES": true, + "SCHEMA": true, + "SCN": true, + "SCOPE": true, + "SD_ALL": true, + "SD_INHIBIT": true, + "SD_SHOW": true, + "SEGMENT": true, + "SEG_BLOCK": true, + "SEG_FILE": true, + "SELECT": true, + "SEQUENCE": true, + "SERIALIZABLE": true, + "SESSION": true, + "SESSION_CACHED_CURSORS": true, + "SESSIONS_PER_USER": true, + "SET": true, + "SHARE": true, + "SHARED": true, + "SHARED_POOL": true, + "SHRINK": true, + "SIZE": true, + "SKIP": true, + "SKIP_UNUSABLE_INDEXES": true, + "SMALLINT": true, + "SNAPSHOT": true, + "SOME": true, + "SORT": true, + "SPECIFICATION": true, + "SPLIT": true, + "SQL_TRACE": true, + "STANDBY": true, + "START": true, + "STATEMENT_ID": true, + "STATISTICS": true, + "STOP": true, + "STORAGE": true, + "STORE": true, + "STRUCTURE": true, + "SUCCESSFUL": true, + "SWITCH": true, + "SYS_OP_ENFORCE_NOT_NULL$": true, + "SYS_OP_NTCIMG$": true, + "SYNONYM": true, + "SYSDATE": true, + "SYSDBA": true, + "SYSOPER": true, + "SYSTEM": true, + "TABLE": true, + "TABLES": true, + "TABLESPACE": true, + "TABLESPACE_NO": true, + "TABNO": true, + "TEMPORARY": true, + "THAN": true, + "THE": true, + "THEN": true, + "THREAD": true, + "TIMESTAMP": true, + "TIME": true, + "TO": true, + "TOPLEVEL": true, + "TRACE": true, + "TRACING": true, + "TRANSACTION": true, + "TRANSITIONAL": true, + "TRIGGER": true, + "TRIGGERS": true, + "TRUE": true, + "TRUNCATE": true, + "TX": true, + "TYPE": true, + "UB2": true, + "UBA": true, + "UID": true, + "UNARCHIVED": true, + "UNDO": true, + "UNION": true, + "UNIQUE": true, + "UNLIMITED": true, + "UNLOCK": true, + "UNRECOVERABLE": true, + "UNTIL": true, + "UNUSABLE": true, + "UNUSED": true, + "UPDATABLE": true, + "UPDATE": true, + "USAGE": true, + "USE": true, + "USER": true, + "USING": true, + "VALIDATE": true, + "VALIDATION": true, + "VALUE": true, + "VALUES": true, + "VARCHAR": true, + "VARCHAR2": true, + "VARYING": true, + "VIEW": true, + "WHEN": true, + "WHENEVER": true, + "WHERE": true, + "WITH": true, + "WITHOUT": true, + "WORK": true, + "WRITE": true, + "WRITEDOWN": true, + "WRITEUP": true, + "XID": true, + "YEAR": true, + "ZONE": true, + } + + damengQuoter = schemas.Quoter{ + Prefix: '"', + Suffix: '"', + IsReserved: schemas.AlwaysReserve, + } +) + +type dameng struct { + Base +} + +func (db *dameng) Init(uri *URI) error { + db.quoter = damengQuoter + return db.Base.Init(db, uri) +} + +func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, "SELECT * FROM V$VERSION") // select id_code + if err != nil { + return nil, err + } + defer rows.Close() + + var version string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version); err != nil { + return nil, err + } + return &schemas.Version{ + Number: version, + }, nil +} + +func (db *dameng) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: SequenceAutoincrMode, + } +} + +// DropIndexSQL returns a SQL to drop index +func (db *dameng) DropIndexSQL(tableName string, index *schemas.Index) string { + quote := db.dialect.Quoter().Quote + var name string + if index.IsRegular { + name = index.XName(tableName) + } else { + name = index.Name + } + return fmt.Sprintf("DROP INDEX %v", quote(name)) +} + +func (db *dameng) SQLType(c *schemas.Column) string { + var res string + switch t := c.SQLType.Name; t { + case schemas.TinyInt, "BYTE": + return "TINYINT" + case schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedTinyInt: + return "INTEGER" + case schemas.BigInt, + schemas.UnsignedBigInt, schemas.UnsignedBit, schemas.UnsignedInt, + schemas.Serial, schemas.BigSerial: + return "BIGINT" + case schemas.Bit, schemas.Bool, schemas.Boolean: + return schemas.Bit + case schemas.Uuid: + res = schemas.Varchar + c.Length = 40 + case schemas.Binary: + if c.Length == 0 { + return schemas.Binary + "(MAX)" + } + case schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: + return schemas.VarBinary + case schemas.Date: + return schemas.Date + case schemas.Time: + if c.Length > 0 { + return fmt.Sprintf("%s(%d)", schemas.Time, c.Length) + } + return schemas.Time + case schemas.DateTime, schemas.TimeStamp: + res = schemas.TimeStamp + case schemas.TimeStampz: + if c.Length > 0 { + return fmt.Sprintf("TIMESTAMP(%d) WITH TIME ZONE", c.Length) + } + return "TIMESTAMP WITH TIME ZONE" + case schemas.Float: + res = "FLOAT" + case schemas.Real, schemas.Double: + res = "REAL" + case schemas.Numeric, schemas.Decimal, "NUMBER": + res = "NUMERIC" + case schemas.Text, schemas.Json: + return "TEXT" + case schemas.MediumText, schemas.LongText: + res = "CLOB" + case schemas.Char, schemas.Varchar, schemas.TinyText: + res = "VARCHAR2" + default: + res = t + } + + hasLen1 := (c.Length > 0) + hasLen2 := (c.Length2 > 0) + + if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } else if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } + return res +} + +func (db *dameng) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE": + return schemas.TIME_TYPE + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + return schemas.TEXT_TYPE + case "NUMBER": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + +func (db *dameng) AutoIncrStr() string { + return "IDENTITY" +} + +func (db *dameng) IsReserved(name string) bool { + _, ok := damengReservedWords[strings.ToUpper(name)] + return ok +} + +func (db *dameng) DropTableSQL(tableName string) (string, bool) { + return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false +} + +// ModifyColumnSQL returns a SQL to modify SQL +func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false) + return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) +} + +func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + if tableName == "" { + tableName = table.Name + } + + quoter := db.Quoter() + var b strings.Builder + b.WriteString("CREATE TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") + + pkList := table.PrimaryKeys + + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.SQLType.IsBool() && !col.DefaultIsEmpty { + if col.Default == "true" { + col.Default = "1" + } else if col.Default == "false" { + col.Default = "0" + } + } + + s, _ := ColumnString(db, col, false) + b.WriteString(s) + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") + } + } + + if len(pkList) > 0 { + if len(table.ColumnsSeq()) > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)) + quoter.JoinWrite(&b, pkList, ",") + b.WriteString(")") + } + b.WriteString(")") + + return b.String(), false, nil +} + +func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = damengQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = damengQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = damengQuoter + } +} + +func (db *dameng) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + args := []interface{}{tableName, idxName} + return `SELECT INDEX_NAME FROM USER_INDEXES ` + + `WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args +} + +func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName) +} + +func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + var cnt int + rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + if rows.Err() != nil { + return false, rows.Err() + } + return false, errors.New("query sequence failed") + } + + if err := rows.Scan(&cnt); err != nil { + return false, err + } + return cnt > 0, nil +} + +func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + args := []interface{}{tableName, colName} + query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + + " AND column_name = ?" + return db.HasRecords(queryer, ctx, query, args...) +} + +var _ sql.Scanner = &dmClobScanner{} + +type dmClobScanner struct { + valid bool + data string +} + +type dmClobObject interface { + GetLength() (int64, error) + ReadString(int, int) (string, error) +} + +//var _ dmClobObject = &dm.DmClob{} + +func (d *dmClobScanner) Scan(data interface{}) error { + if data == nil { + return nil + } + + switch t := data.(type) { + case dmClobObject: // *dm.DmClob + if t == nil { + return nil + } + l, err := t.GetLength() + if err != nil { + return err + } + if l == 0 { + d.valid = true + return nil + } + d.data, err = t.ReadString(1, int(l)) + if err != nil { + return err + } + d.valid = true + return nil + case []byte: + if t == nil { + return nil + } + d.data = string(t) + d.valid = true + return nil + default: + return fmt.Errorf("cannot convert %T as dmClobScanner", data) + } +} + +func addSingleQuote(name string) string { + if len(name) < 2 { + return name + } + if name[0] == '\'' && name[len(name)-1] == '\'' { + return name + } + return fmt.Sprintf("'%s'", name) +} + +func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { + s := `select column_name from user_cons_columns + where constraint_name = (select constraint_name from user_constraints + where table_name = ? and constraint_type ='P')` + rows, err := queryer.QueryContext(ctx, s, tableName) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + var pkNames []string + for rows.Next() { + var pkName string + err = rows.Scan(&pkName) + if err != nil { + return nil, nil, err + } + pkNames = append(pkNames, pkName) + } + if rows.Err() != nil { + return nil, nil, rows.Err() + } + rows.Close() + + s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH, + USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE, + user_col_comments.comments + FROM USER_TAB_COLS + LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME + AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME + WHERE USER_TAB_COLS.table_name = ?` + rows, err = queryer.QueryContext(ctx, s, tableName) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + cols := make(map[string]*schemas.Column) + colSeq := make([]string, 0) + for rows.Next() { + col := new(schemas.Column) + col.Indexes = make(map[string]int) + + var colDefault dmClobScanner + var colName, nullable, dataType, dataPrecision, comment sql.NullString + var dataScale, dataLen sql.NullInt64 + + err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, + &dataScale, &nullable, &comment) + if err != nil { + return nil, nil, err + } + + if !colName.Valid { + return nil, nil, errors.New("column name is nil") + } + + col.Name = strings.Trim(colName.String, `" `) + if colDefault.valid { + col.Default = colDefault.data + } else { + col.DefaultIsEmpty = true + } + + if nullable.String == "Y" { + col.Nullable = true + } else { + col.Nullable = false + } + + if !comment.Valid { + col.Comment = comment.String + } + if utils.IndexSlice(pkNames, col.Name) > -1 { + col.IsPrimaryKey = true + has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", utils.SeqName(tableName)) + if err != nil { + return nil, nil, err + } + if has { + col.IsAutoIncrement = true + } + } + + var ( + ignore bool + dt string + len1, len2 int + ) + + dts := strings.Split(dataType.String, "(") + dt = dts[0] + if len(dts) > 1 { + lens := strings.Split(dts[1][:len(dts[1])-1], ",") + if len(lens) > 1 { + len1, _ = strconv.Atoi(lens[0]) + len2, _ = strconv.Atoi(lens[1]) + } else { + len1, _ = strconv.Atoi(lens[0]) + } + } + + switch dt { + case "VARCHAR2": + col.SQLType = schemas.SQLType{Name: "VARCHAR2", DefaultLength: len1, DefaultLength2: len2} + case "VARCHAR": + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2} + case "TIMESTAMP WITH TIME ZONE": + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + case "NUMBER": + col.SQLType = schemas.SQLType{Name: "NUMBER", DefaultLength: len1, DefaultLength2: len2} + case "LONG", "LONG RAW", "NCLOB", "CLOB", "TEXT": + col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0} + case "RAW": + col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0} + case "ROWID": + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 18, DefaultLength2: 0} + case "AQ$_SUBSCRIBERS": + ignore = true + default: + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} + } + + if ignore { + continue + } + + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { + return nil, nil, fmt.Errorf("unknown colType %v %v", dataType.String, col.SQLType) + } + + if col.SQLType.Name == "TIMESTAMP" { + col.Length = int(dataScale.Int64) + } else { + col.Length = int(dataLen.Int64) + } + + if col.SQLType.IsTime() { + if !col.DefaultIsEmpty && !strings.EqualFold(col.Default, "CURRENT_TIMESTAMP") { + col.Default = addSingleQuote(col.Default) + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + if rows.Err() != nil { + return nil, nil, rows.Err() + } + + return colSeq, cols, nil +} + +func (db *dameng) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { + s := "SELECT table_name FROM user_tables WHERE temporary = 'N' AND table_name NOT LIKE ?" + args := []interface{}{strings.ToUpper(db.uri.User), "%$%"} + + rows, err := queryer.QueryContext(ctx, s, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + tables := make([]*schemas.Table, 0) + for rows.Next() { + table := schemas.NewEmptyTable() + err = rows.Scan(&table.Name) + if err != nil { + return nil, err + } + + tables = append(tables, table) + } + if rows.Err() != nil { + return nil, rows.Err() + } + return tables, nil +} + +func (db *dameng) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { + args := []interface{}{tableName, tableName} + s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =?" + + " AND t.index_name not in (SELECT index_name FROM ALL_CONSTRAINTS WHERE CONSTRAINT_TYPE='P' AND table_name = ?)" + + rows, err := queryer.QueryContext(ctx, s, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + indexes := make(map[string]*schemas.Index) + for rows.Next() { + var indexType int + var indexName, colName, uniqueness string + + err = rows.Scan(&colName, &uniqueness, &indexName) + if err != nil { + return nil, err + } + + indexName = strings.Trim(indexName, `" `) + + var isRegular bool + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + indexName = indexName[5+len(tableName):] + isRegular = true + } + + if uniqueness == "UNIQUE" { + indexType = schemas.UniqueType + } else { + indexType = schemas.IndexType + } + + var index *schemas.Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(schemas.Index) + index.Type = indexType + index.Name = indexName + index.IsRegular = isRegular + indexes[indexName] = index + } + index.AddColumn(colName) + } + if rows.Err() != nil { + return nil, rows.Err() + } + return indexes, nil +} + +func (db *dameng) Filters() []Filter { + return []Filter{} +} + +type damengDriver struct { + baseDriver +} + +// Features return features +func (d *damengDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + +// Parse parse the datasource +// dm://userName:password@ip:port +func (d *damengDriver) Parse(driverName, dataSourceName string) (*URI, error) { + u, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + + if u.User == nil { + return nil, errors.New("user/password needed") + } + + passwd, _ := u.User.Password() + return &URI{ + DBType: schemas.DAMENG, + Proto: u.Scheme, + Host: u.Hostname(), + Port: u.Port(), + DBName: u.User.Username(), + User: u.User.Username(), + Passwd: passwd, + }, nil +} + +func (d *damengDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + var s sql.NullString + return &s, nil + case "NUMBER": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "INTEGER": + var s sql.NullInt32 + return &s, nil + case "DATE", "TIMESTAMP": + var s sql.NullString + return &s, nil + case "BLOB": + var r sql.RawBytes + return &r, nil + case "FLOAT": + var s sql.NullFloat64 + return &s, nil + default: + var r sql.RawBytes + return &r, nil + } +} + +func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for i, v := range vv { + var replaced bool + var scanResult interface{} + switch types[i].DatabaseTypeName() { + case "CLOB", "TEXT": + scanResult = &dmClobScanner{} + replaced = true + case "TIMESTAMP": + scanResult = &sql.NullString{} + replaced = true + default: + scanResult = v + } + + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + if err = rows.Scan(scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + switch t := scanResults[i].(type) { + case *dmClobScanner: + var d interface{} + if t.valid { + d = t.data + } else { + d = nil + } + if err := convert.Assign(vv[i], d, ctx.DBLocation, ctx.UserLocation); err != nil { + return err + } + default: + switch types[i].DatabaseTypeName() { + case "TIMESTAMP": + ns := t.(*sql.NullString) + if !ns.Valid { + return nil + } + s := ns.String + fields := strings.Split(s, "+") + if err := convert.Assign(vv[i], strings.Replace(fields[0], "T", " ", -1), ctx.DBLocation, ctx.UserLocation); err != nil { + return err + } + default: + return fmt.Errorf("don't support convert %T to %T", t, vv[i]) + } + } + } + } + + return nil +} diff --git a/dialects/dialect.go b/dialects/dialect.go index b6c0853a..460ab56a 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -38,11 +38,21 @@ func (uri *URI) SetSchema(schema string) { } } +const ( + IncrAutoincrMode = iota + SequenceAutoincrMode +) + +type DialectFeatures struct { + AutoincrMode int // 0 autoincrement column, 1 sequence +} + // Dialect represents a kind of database type Dialect interface { Init(*URI) error URI() *URI Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) + Features() *DialectFeatures SQLType(*schemas.Column) string Alias(string) string // return what a sql type's alias of @@ -61,9 +71,13 @@ type Dialect interface { GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) - CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) + CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) DropTableSQL(tableName string) (string, bool) + CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) + IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) + DropSequenceSQL(seqName string) (string, error) + GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) AddColumnSQL(tableName string, col *schemas.Column) string @@ -104,7 +118,7 @@ func (db *Base) URI() *URI { } // CreateTableSQL implements Dialect -func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { +func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { if tableName == "" { tableName = table.Name } @@ -133,7 +147,25 @@ func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string b.WriteString(")") - return []string{b.String()}, false + return b.String(), false, nil +} + +func (db *Base) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) { + return fmt.Sprintf(`CREATE SEQUENCE %s + minvalue 1 + nomaxvalue + start with 1 + increment by 1 + nocycle + nocache`, seqName), nil +} + +func (db *Base) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + return false, fmt.Errorf("unsupported sequence feature") +} + +func (db *Base) DropSequenceSQL(seqName string) (string, error) { + return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil } // DropTableSQL returns drop table SQL @@ -285,43 +317,41 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } - if err := bd.WriteByte(' '); err != nil { - return "", err - } - if includePrimaryKey && col.IsPrimaryKey { - if _, err := bd.WriteString("PRIMARY KEY "); err != nil { + if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err } - if col.IsAutoIncrement { - if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { - return "", err - } if err := bd.WriteByte(' '); err != nil { return "", err } + if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { + return "", err + } } } - if col.Default != "" { - if _, err := bd.WriteString("DEFAULT "); err != nil { + if !col.DefaultIsEmpty { + if _, err := bd.WriteString(" DEFAULT "); err != nil { return "", err } - if _, err := bd.WriteString(col.Default); err != nil { - return "", err - } - if err := bd.WriteByte(' '); err != nil { - return "", err + if col.Default == "" { + if _, err := bd.WriteString("''"); err != nil { + return "", err + } + } else { + if _, err := bd.WriteString(col.Default); err != nil { + return "", err + } } } if col.Nullable { - if _, err := bd.WriteString("NULL "); err != nil { + if _, err := bd.WriteString(" NULL"); err != nil { return "", err } } else { - if _, err := bd.WriteString("NOT NULL "); err != nil { + if _, err := bd.WriteString(" NOT NULL"); err != nil { return "", err } } diff --git a/dialects/mssql.go b/dialects/mssql.go index ab010eb0..cd19afb9 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -282,6 +282,12 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve }, nil } +func (db *mssql) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *mssql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { @@ -625,7 +631,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { +func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { if tableName == "" { tableName = table.Name } @@ -656,7 +662,7 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin b.WriteString(")") - return []string{b.String()}, true + return b.String(), true, nil } func (db *mssql) ForUpdateSQL(query string) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 0489904a..ce3bd705 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -244,6 +244,12 @@ func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve }, nil } +func (db *mysql) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *mysql) SetParams(params map[string]string) { rowFormat, ok := params["rowFormat"] if ok { @@ -625,7 +631,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName return indexes, nil } -func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { +func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { if tableName == "" { tableName = table.Name } @@ -678,7 +684,8 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin b.WriteString(" ROW_FORMAT=") b.WriteString(db.rowFormat) } - return []string{b.String()}, true + + return b.String(), true, nil } func (db *mysql) Filters() []Filter { diff --git a/dialects/oracle.go b/dialects/oracle.go index 11a6653b..04652bd6 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -539,6 +539,12 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V }, nil } +func (db *oracle) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: SequenceAutoincrMode, + } +} + func (db *oracle) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { @@ -599,7 +605,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { +func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name @@ -629,7 +635,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri } sql = sql[:len(sql)-2] + ")" - return []string{sql}, false + return sql, false, nil } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { diff --git a/dialects/postgres.go b/dialects/postgres.go index 6b5a8b2f..822d3a70 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -941,6 +941,12 @@ func (db *postgres) SQLType(c *schemas.Column) string { return res } +func (db *postgres) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *postgres) ColumnTypeKind(t string) int { switch strings.ToUpper(t) { case "DATETIME", "TIMESTAMP": diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 4eba8dad..4ff9a39e 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -184,6 +184,12 @@ func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas. }, nil } +func (db *sqlite3) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: diff --git a/engine.go b/engine.go index ec066109..3681ff3c 100644 --- a/engine.go +++ b/engine.go @@ -248,11 +248,6 @@ func (engine *Engine) SQLType(c *schemas.Column) string { return engine.dialect.SQLType(c) } -// AutoIncrStr Database's autoincrement statement -func (engine *Engine) AutoIncrStr() string { - return engine.dialect.AutoIncrStr() -} - // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. func (engine *Engine) SetConnMaxLifetime(d time.Duration) { engine.DB().SetConnMaxLifetime(d) @@ -441,7 +436,7 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp .. // DumpTables dump specify tables to io.Writer func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { - return engine.dumpTables(tables, w, tp...) + return engine.dumpTables(context.Background(), tables, w, tp...) } func formatBool(s string, dstDialect dialects.Dialect) string { @@ -457,7 +452,7 @@ func formatBool(s string, dstDialect dialects.Dialect) string { } // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { +func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { var dstDialect dialects.Dialect if len(tp) == 0 { dstDialect = engine.dialect @@ -494,9 +489,12 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } } - dstTableName := dstTable.Name + var dstTableName = dstTable.Name + var quoter = dstDialect.Quoter().Quote + var quotedDstTableName = quoter(dstTable.Name) if dstDialect.URI().Schema != "" { dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name) + quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name)) } originalTableName := table.Name if engine.dialect.URI().Schema != "" { @@ -509,13 +507,26 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } } - sqls, _ := dstDialect.CreateTableSQL(dstTable, dstTableName) - for _, s := range sqls { - _, err = io.WriteString(w, s+";\n") + if dstTable.AutoIncrement != "" && dstDialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { + sqlstr, err := dstDialect.CreateSequenceSQL(ctx, engine.db, utils.SeqName(dstTableName)) + if err != nil { + return err + } + _, err = io.WriteString(w, sqlstr+";\n") if err != nil { return err } } + + sqlstr, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName) + if err != nil { + return err + } + _, err = io.WriteString(w, sqlstr+";\n") + if err != nil { + return err + } + if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name) } @@ -552,7 +563,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+quotedDstTableName+" ("+destColNames+") VALUES (") if err != nil { return err } @@ -563,36 +574,27 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } for i, scanResult := range scanResults { stp := schemas.SQLType{Name: types[i].DatabaseTypeName()} - if stp.IsNumeric() { - s := scanResult.(*sql.NullString) - if s.Valid { - if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { - return err - } - } else { - if _, err = io.WriteString(w, "NULL"); err != nil { - return err - } - } - } else if stp.IsBool() { - s := scanResult.(*sql.NullString) - if s.Valid { - if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { - return err - } - } else { - if _, err = io.WriteString(w, "NULL"); err != nil { - return err - } + s := scanResult.(*sql.NullString) + if !s.Valid { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err } } else { - s := scanResult.(*sql.NullString) - if s.Valid { - if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { + if stp.IsBool() || (dstDialect.URI().DBType == schemas.MSSQL && strings.EqualFold(stp.Name, schemas.Bit)) { + if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { + return err + } + } else if stp.IsNumeric() { + if _, err = io.WriteString(w, s.String); err != nil { + return err + } + } else if sess.engine.dialect.URI().DBType == schemas.DAMENG && stp.IsTime() && len(s.String) == 25 { + r := strings.Replace(s.String[:19], "T", " ", -1) + if _, err = io.WriteString(w, "'"+r+"'"); err != nil { return err } } else { - if _, err = io.WriteString(w, "NULL"); err != nil { + if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { return err } } diff --git a/go.mod b/go.mod index d645011c..98b5617c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module xorm.io/xorm go 1.13 require ( + gitee.com/travelliu/dm v1.8.11192 github.com/denisenkom/go-mssqldb v0.10.0 github.com/go-sql-driver/mysql v1.6.0 github.com/goccy/go-json v0.7.4 diff --git a/go.sum b/go.sum index e8024945..4596f326 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= +gitee.com/travelliu/dm v1.8.11192 h1:aqJT0xhodZjRutIfEXxKYv0CxqmHUHzsbz6SFaRL6OY= +gitee.com/travelliu/dm v1.8.11192/go.mod h1:DHTzyhCrM843x9VdKVbZ+GKXGRbKM2sJ4LxihRxShkE= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= @@ -95,8 +97,9 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= diff --git a/integrations/cache_test.go b/integrations/cache_test.go index 44e817b1..6fdaa5cb 100644 --- a/integrations/cache_test.go +++ b/integrations/cache_test.go @@ -62,7 +62,8 @@ func TestCacheFind(t *testing.T) { } boxes = make([]MailBox, 0, 2) - assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes)) + assert.NoError(t, testEngine.Alias("a").Where("`a`.`id`> -1"). + Asc("`a`.`id`").Find(&boxes)) assert.EqualValues(t, 2, len(boxes)) for i, box := range boxes { assert.Equal(t, inserts[i].Id, box.Id) @@ -77,7 +78,8 @@ func TestCacheFind(t *testing.T) { } boxes2 := make([]MailBox4, 0, 2) - assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2)) + assert.NoError(t, testEngine.Table("mail_box").Where("`mail_box`.`id` > -1"). + Asc("mail_box.id").Find(&boxes2)) assert.EqualValues(t, 2, len(boxes2)) for i, box := range boxes2 { assert.Equal(t, inserts[i].Id, box.Id) @@ -164,14 +166,14 @@ func TestCacheGet(t *testing.T) { assert.NoError(t, err) var box1 MailBox3 - has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) + has, err := testEngine.Where("`id` = ?", inserts[0].Id).Get(&box1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box1.Username) assert.EqualValues(t, "pass1", box1.Password) var box2 MailBox3 - has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) + has, err = testEngine.Where("`id` = ?", inserts[0].Id).Get(&box2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box2.Username) diff --git a/integrations/engine_dm_test.go b/integrations/engine_dm_test.go new file mode 100644 index 00000000..6c2f6103 --- /dev/null +++ b/integrations/engine_dm_test.go @@ -0,0 +1,13 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build dm + +package integrations + +import "xorm.io/xorm/schemas" + +func init() { + dbtypes = append(dbtypes, schemas.DAMENG) +} diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 02b35a2c..f45d6859 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -14,6 +14,7 @@ import ( "xorm.io/xorm" "xorm.io/xorm/schemas" + _ "gitee.com/travelliu/dm" _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v4/stdlib" @@ -134,6 +135,8 @@ func TestDump(t *testing.T) { } } +var dbtypes = []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} + func TestDumpTables(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -169,7 +172,7 @@ func TestDumpTables(t *testing.T) { assert.NoError(t, err) assert.NoError(t, sess.Commit()) - for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { + for _, tp := range dbtypes { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) diff --git a/integrations/performance_test.go b/integrations/performance_test.go index 4b54b40c..241dc7b9 100644 --- a/integrations/performance_test.go +++ b/integrations/performance_test.go @@ -32,7 +32,7 @@ func BenchmarkGetVars(b *testing.B) { b.StartTimer() var myname string for i := 0; i < b.N; i++ { - has, err := testEngine.Cols("name").Table("benchmark_get_vars").Where("id=?", v.Id).Get(&myname) + has, err := testEngine.Cols("name").Table("benchmark_get_vars").Where("`id`=?", v.Id).Get(&myname) b.StopTimer() myname = "" assert.True(b, has) diff --git a/integrations/session_cols_test.go b/integrations/session_cols_test.go index b74c6f8a..18ba4001 100644 --- a/integrations/session_cols_test.go +++ b/integrations/session_cols_test.go @@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) { assert.EqualValues(t, 1, cnt) var not = "NOT" - if testEngine.Dialect().URI().DBType == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL || testEngine.Dialect().URI().DBType == schemas.DAMENG { not = "~" } cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) @@ -54,9 +54,9 @@ func TestSetExpr(t *testing.T) { tableName := testEngine.TableName(new(UserExprIssue), true) cnt, err = testEngine.SetExpr("issue_id", - builder.Select("id"). - From(tableName). - Where(builder.Eq{"id": issue.Id})). + builder.Select("`id`"). + From(testEngine.Quote(tableName)). + Where(builder.Eq{"`id`": issue.Id})). ID(1). Update(new(UserExpr)) assert.NoError(t, err) diff --git a/integrations/session_cond_test.go b/integrations/session_cond_test.go index a0a91cad..72a4caf5 100644 --- a/integrations/session_cond_test.go +++ b/integrations/session_cond_test.go @@ -37,49 +37,50 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) var cond Condition - has, err := testEngine.Where(builder.Eq{"col_name": "col1"}).Get(&cond) + var q = testEngine.Quote + has, err := testEngine.Where(builder.Eq{q("col_name"): "col1"}).Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Eq{"op": OpEqual})). + has, err = testEngine.Where(builder.Eq{q("col_name"): "col1"}. + And(builder.Eq{q("op"): OpEqual})). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1", "op": OpEqual, "value": "1"}). + has, err = testEngine.Where(builder.Eq{q("col_name"): "col1", q("op"): OpEqual, q("value"): "1"}). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Neq{"op": OpEqual})). + has, err = testEngine.Where(builder.Eq{q("col_name"): "col1"}. + And(builder.Neq{q("op"): OpEqual})). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, false, has, "records should not exist") var conds []Condition - err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Eq{"op": OpEqual})). + err = testEngine.Where(builder.Eq{q("col_name"): "col1"}. + And(builder.Eq{q("op"): OpEqual})). Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.Like{"col_name", "col"}).Find(&conds) + err = testEngine.Where(builder.Like{q("col_name"), "col"}).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.Expr("col_name = ?", "col1")).Find(&conds) + err = testEngine.Where(builder.Expr(q("col_name")+" = ?", "col1")).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.In("col_name", "col1", "col2")).Find(&conds) + err = testEngine.Where(builder.In(q("col_name"), "col1", "col2")).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") @@ -91,8 +92,8 @@ func TestBuilder(t *testing.T) { // complex condtions var where = builder.NewCond() if true { - where = where.And(builder.Eq{"col_name": "col1"}) - where = where.Or(builder.And(builder.In("col_name", "col1", "col2"), builder.Expr("col_name = ?", "col1"))) + where = where.And(builder.Eq{q("col_name"): "col1"}) + where = where.Or(builder.And(builder.In(q("col_name"), "col1", "col2"), builder.Expr(q("col_name")+" = ?", "col1"))) } conds = make([]Condition, 0) @@ -215,7 +216,7 @@ func TestFindAndCount(t *testing.T) { assert.NoError(t, err) var results []FindAndCount - sess := testEngine.Where("name = ?", "test1") + sess := testEngine.Where("`name` = ?", "test1") conds := sess.Conds() err = sess.Find(&results) assert.NoError(t, err) diff --git a/integrations/session_count_test.go b/integrations/session_count_test.go index 1517dede..e49b6045 100644 --- a/integrations/session_count_test.go +++ b/integrations/session_count_test.go @@ -63,7 +63,7 @@ func TestSQLCount(t *testing.T) { assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)). + total, err := testEngine.SQL("SELECT count(`id`) FROM " + testEngine.Quote(testEngine.TableName("userinfo_count2", true))). Count() assert.NoError(t, err) assert.EqualValues(t, 0, total) @@ -89,7 +89,7 @@ func TestCountWithOthers(t *testing.T) { }) assert.NoError(t, err) - total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers)) + total, err := testEngine.OrderBy("`id` desc").Limit(1).Count(new(CountWithOthers)) assert.NoError(t, err) assert.EqualValues(t, 2, total) } @@ -118,11 +118,11 @@ func TestWithTableName(t *testing.T) { }) assert.NoError(t, err) - total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName)) + total, err := testEngine.OrderBy("`id` desc").Count(new(CountWithTableName)) assert.NoError(t, err) assert.EqualValues(t, 2, total) - total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{}) + total, err = testEngine.OrderBy("`id` desc").Count(CountWithTableName{}) assert.NoError(t, err) assert.EqualValues(t, 2, total) } @@ -146,7 +146,7 @@ func TestCountWithSelectCols(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, total) - total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) + total, err = testEngine.Select("count(`id`)").Count(CountWithTableName{}) assert.NoError(t, err) assert.EqualValues(t, 2, total) } @@ -166,7 +166,7 @@ func TestCountWithGroupBy(t *testing.T) { }) assert.NoError(t, err) - cnt, err := testEngine.GroupBy("name").Count(new(CountWithTableName)) + cnt, err := testEngine.GroupBy("`name`").Count(new(CountWithTableName)) assert.NoError(t, err) assert.EqualValues(t, 2, cnt) } diff --git a/integrations/session_exist_test.go b/integrations/session_exist_test.go index 29546376..a0f65211 100644 --- a/integrations/session_exist_test.go +++ b/integrations/session_exist_test.go @@ -48,19 +48,19 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) + has, err = testEngine.Where("`name` = ?", "test1").Exist(&RecordExist{}) assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.Where("name = ?", "test2").Exist(&RecordExist{}) + has, err = testEngine.Where("`name` = ?", "test2").Exist(&RecordExist{}) assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist() + has, err = testEngine.SQL("select * from "+testEngine.Quote(testEngine.TableName("record_exist", true))+" where `name` = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist() + has, err = testEngine.SQL("select * from "+testEngine.Quote(testEngine.TableName("record_exist", true))+" where `name` = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) @@ -68,11 +68,11 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() + has, err = testEngine.Table("record_exist").Where("`name` = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() + has, err = testEngine.Table("record_exist").Where("`name` = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) @@ -124,43 +124,43 @@ func TestExistStructForJoin(t *testing.T) { defer session.Close() session.Table("number"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("number.lid = ?", 1) + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`number`.`lid` = ?", 1) has, err := session.Exist() assert.NoError(t, err) assert.True(t, has) session.Table("number"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("number.lid = ?", 2) + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`number`.`lid` = ?", 2) has, err = session.Exist() assert.NoError(t, err) assert.False(t, has) session.Table("number"). - Select("order_list.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("order_list.id = ?", 1) + Select("`order_list`.`id`"). + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`order_list`.`id` = ?", 1) has, err = session.Exist() assert.NoError(t, err) assert.True(t, has) session.Table("number"). Select("player.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("player.id = ?", 2) + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`player`.`id` = ?", 2) has, err = session.Exist() assert.NoError(t, err) assert.False(t, has) session.Table("number"). Select("player.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid") + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`") has, err = session.Exist() assert.NoError(t, err) assert.True(t, has) @@ -174,15 +174,15 @@ func TestExistStructForJoin(t *testing.T) { session.Table("number"). Select("player.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid") + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`") has, err = session.Exist() assert.Error(t, err) assert.False(t, has) session.Table("number"). Select("player.id"). - Join("LEFT", "player", "player.id = number.lid") + Join("LEFT", "player", "`player`.`id` = `number`.`lid`") has, err = session.Exist() assert.NoError(t, err) assert.True(t, has) diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 1cbf5e42..9b503f25 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -56,8 +56,8 @@ func TestJoinLimit(t *testing.T) { var salaries []Salary err = testEngine.Table("salary"). - Join("INNER", "check_list", "check_list.id = salary.lid"). - Join("LEFT", "empsetting", "empsetting.id = check_list.eid"). + Join("INNER", "check_list", "`check_list`.`id` = `salary`.`lid`"). + Join("LEFT", "empsetting", "`empsetting`.`id` = `check_list`.`eid`"). Limit(10, 0). Find(&salaries) assert.NoError(t, err) @@ -69,10 +69,10 @@ func TestWhere(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.Where("id > ?", 2).Find(&users) + err := testEngine.Where("`id` > ?", 2).Find(&users) assert.NoError(t, err) - err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) + err = testEngine.Where("`id` > ?", 2).And("`id` < ?", 10).Find(&users) assert.NoError(t, err) } @@ -125,50 +125,50 @@ func TestFind3(t *testing.T) { assert.NoError(t, err) var teams []Team - err = testEngine.Cols("`team`.id"). - Where("`team_user`.org_id=?", 1). - And("`team_user`.uid=?", 2). - Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`team_user`.`org_id`=?", 1). + And("`team_user`.`uid`=?", 2). + Join("INNER", "`team_user`", "`team_user`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) err = testEngine.Cols("`team`.id"). - Where("`team_user`.org_id=?", 1). - And("`team_user`.uid=?", 2). - Join("INNER", teamUser, "`team_user`.team_id=`team`.id"). + Where("`team_user`.`org_id`=?", 1). + And("`team_user`.`uid`=?", 2). + Join("INNER", teamUser, "`team_user`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`team_user`.org_id=?", 1). - And("`team_user`.uid=?", 2). - Join("INNER", []interface{}{teamUser}, "`team_user`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`team_user`.`org_id`=?", 1). + And("`team_user`.`uid`=?", 2). + Join("INNER", []interface{}{teamUser}, "`team_user`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`tu`.org_id=?", 1). - And("`tu`.uid=?", 2). - Join("INNER", []string{"team_user", "tu"}, "`tu`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`tu`.`org_id`=?", 1). + And("`tu`.`uid`=?", 2). + Join("INNER", []string{"team_user", "tu"}, "`tu`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`tu`.org_id=?", 1). - And("`tu`.uid=?", 2). - Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`tu`.`org_id`=?", 1). + And("`tu`.`uid`=?", 2). + Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`tu`.org_id=?", 1). - And("`tu`.uid=?", 2). - Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`tu`.`org_id`=?", 1). + And("`tu`.`uid`=?", 2). + Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) } @@ -241,7 +241,7 @@ func TestOrder(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.OrderBy("id desc").Find(&users) + err := testEngine.OrderBy("`id` desc").Find(&users) assert.NoError(t, err) users2 := make([]Userinfo, 0) @@ -254,7 +254,7 @@ func TestGroupBy(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.GroupBy("id, username").Find(&users) + err := testEngine.GroupBy("`id`, `username`").Find(&users) assert.NoError(t, err) } @@ -263,7 +263,7 @@ func TestHaving(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) + err := testEngine.GroupBy("`username`").Having("`username`='xlw'").Find(&users) assert.NoError(t, err) } @@ -499,7 +499,7 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 2, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).FindAndCount(&results) + cnt, err = testEngine.Where("`msg` = ?", true).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) @@ -549,21 +549,21 @@ func TestFindAndCountOneFunc(t *testing.T) { }, results[0]) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg"). + cnt, err = testEngine.Where("`msg` = ?", true).Select("`id`, `content`, `msg`"). Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Cols("id", "content", "msg"). + cnt, err = testEngine.Where("`msg` = ?", true).Cols("id", "content", "msg"). Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Desc("id"). + cnt, err = testEngine.Where("`msg` = ?", true).Desc("id"). Limit(1).Cols("content").FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) @@ -649,7 +649,7 @@ func TestFindAndCount2(t *testing.T) { cnt, err = testEngine. Table(new(TestFindAndCountHotel)). Alias("t"). - Where("t.region like '6501%'"). + Where("`t`.`region` like '6501%'"). Limit(10, 0). FindAndCount(&hotels) assert.NoError(t, err) @@ -705,7 +705,7 @@ func TestFindAndCountWithGroupBy(t *testing.T) { assert.NoError(t, err) var results []FindAndCountWithGroupBy - cnt, err := testEngine.GroupBy("age").FindAndCount(&results) + cnt, err := testEngine.GroupBy("`age`").FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, len(results)) @@ -735,14 +735,14 @@ func TestFindMapStringId(t *testing.T) { deviceMaps := make(map[string]*FindMapDevice, len(deviceIDs)) err = testEngine. - Where("status = ?", 1). + Where("`status` = ?", 1). In("deviceid", deviceIDs). Find(&deviceMaps) assert.NoError(t, err) deviceMaps2 := make(map[string]FindMapDevice, len(deviceIDs)) err = testEngine. - Where("status = ?", 1). + Where("`status` = ?", 1). In("deviceid", deviceIDs). Find(&deviceMaps2) assert.NoError(t, err) @@ -919,17 +919,17 @@ func TestFindJoin(t *testing.T) { assertSync(t, new(SceneItem), new(DeviceUserPrivrels), new(Order)) var scenes []SceneItem - err := testEngine.Join("LEFT OUTER", "device_user_privrels", "device_user_privrels.device_id=scene_item.device_id"). - Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) + err := testEngine.Join("LEFT OUTER", "device_user_privrels", "`device_user_privrels`.`device_id`=`scene_item`.`device_id`"). + Where("`scene_item`.`type`=?", 3).Or("`device_user_privrels`.`user_id`=?", 339).Find(&scenes) assert.NoError(t, err) scenes = make([]SceneItem, 0) - err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "device_user_privrels.device_id=scene_item.device_id"). - Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) + err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "`device_user_privrels`.`device_id`=`scene_item`.`device_id`"). + Where("`scene_item`.`type`=?", 3).Or("`device_user_privrels`.`user_id`=?", 339).Find(&scenes) assert.NoError(t, err) scenes = make([]SceneItem, 0) - err = testEngine.Join("INNER", "order", "`scene_item`.device_id=`order`.id").Find(&scenes) + err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes) assert.NoError(t, err) } @@ -949,7 +949,7 @@ func TestJoinFindLimit(t *testing.T) { assertSync(t, new(JoinFindLimit1), new(JoinFindLimit2)) var finds []JoinFindLimit1 - err := testEngine.Join("INNER", new(JoinFindLimit2), "join_find_limit2.eid=join_find_limit1.id"). + err := testEngine.Join("INNER", new(JoinFindLimit2), "`join_find_limit2`.`eid`=`join_find_limit1`.`id`"). Limit(10, 10).Find(&finds) assert.NoError(t, err) } @@ -981,9 +981,9 @@ func TestMoreExtends(t *testing.T) { assertSync(t, new(MoreExtendsUsers), new(MoreExtendsBooks)) var books []MoreExtendsBooksExtend - err := testEngine.Table("more_extends_books").Select("more_extends_books.*, more_extends_users.*"). - Join("INNER", "more_extends_users", "more_extends_books.user_id = more_extends_users.id"). - Where("more_extends_books.name LIKE ?", "abc"). + err := testEngine.Table("more_extends_books").Select("`more_extends_books`.*, `more_extends_users`.*"). + Join("INNER", "more_extends_users", "`more_extends_books`.`user_id` = `more_extends_users`.`id`"). + Where("`more_extends_books`.`name` LIKE ?", "abc"). Limit(10, 10). Find(&books) assert.NoError(t, err) @@ -991,9 +991,9 @@ func TestMoreExtends(t *testing.T) { books = make([]MoreExtendsBooksExtend, 0, len(books)) err = testEngine.Table("more_extends_books"). Alias("m"). - Select("m.*, more_extends_users.*"). - Join("INNER", "more_extends_users", "m.user_id = more_extends_users.id"). - Where("m.name LIKE ?", "abc"). + Select("`m`.*, `more_extends_users`.*"). + Join("INNER", "more_extends_users", "`m`.`user_id` = `more_extends_users`.`id`"). + Where("`m`.`name` LIKE ?", "abc"). Limit(10, 10). Find(&books) assert.NoError(t, err) @@ -1038,11 +1038,11 @@ func TestUpdateFind(t *testing.T) { } _, err := session.Insert(&tuf) assert.NoError(t, err) - _, err = session.Where("id = ?", tuf.Id).Update(&TestUpdateFind{}) + _, err = session.Where("`id` = ?", tuf.Id).Update(&TestUpdateFind{}) assert.EqualError(t, xorm.ErrNoColumnsTobeUpdated, err.Error()) var tufs []TestUpdateFind - err = session.Where("id = ?", tuf.Id).Find(&tufs) + err = session.Where("`id` = ?", tuf.Id).Find(&tufs) assert.NoError(t, err) } diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 99ecd462..4fc30adb 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -15,6 +15,7 @@ import ( "xorm.io/xorm" "xorm.io/xorm/contexts" "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" "xorm.io/xorm/schemas" "github.com/shopspring/decimal" @@ -55,15 +56,15 @@ func TestGetVar(t *testing.T) { assert.Equal(t, 28, age) var ageMax int - has, err = testEngine.SQL("SELECT max(age) FROM "+testEngine.TableName("get_var", true)+" WHERE `id` = ?", data.Id).Get(&ageMax) + has, err = testEngine.SQL("SELECT max(`age`) FROM "+testEngine.Quote(testEngine.TableName("get_var", true))+" WHERE `id` = ?", data.Id).Get(&ageMax) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, 28, ageMax) var age2 int64 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age2) assert.NoError(t, err) assert.Equal(t, true, has) @@ -77,8 +78,8 @@ func TestGetVar(t *testing.T) { var age4 int16 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age4) assert.NoError(t, err) assert.Equal(t, true, has) @@ -86,8 +87,8 @@ func TestGetVar(t *testing.T) { var age5 int32 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age5) assert.NoError(t, err) assert.Equal(t, true, has) @@ -101,8 +102,8 @@ func TestGetVar(t *testing.T) { var age7 int64 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age7) assert.NoError(t, err) assert.Equal(t, true, has) @@ -116,8 +117,8 @@ func TestGetVar(t *testing.T) { var age9 int16 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age9) assert.NoError(t, err) assert.Equal(t, true, has) @@ -125,8 +126,8 @@ func TestGetVar(t *testing.T) { var age10 int32 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age10) assert.NoError(t, err) assert.Equal(t, true, has) @@ -161,16 +162,16 @@ func TestGetVar(t *testing.T) { var money2 float64 if testEngine.Dialect().URI().DBType == schemas.MSSQL { - has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) + has, err = testEngine.SQL("SELECT TOP 1 `money` FROM " + testEngine.Quote(testEngine.TableName("get_var", true))).Get(&money2) } else { - has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) + has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.TableName("get_var", true)) + " LIMIT 1").Get(&money2) } assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) var money3 float64 - has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " WHERE money > 20").Get(&money3) + has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.TableName("get_var", true)) + " WHERE `money` > 20").Get(&money3) assert.NoError(t, err) assert.Equal(t, false, has) @@ -187,7 +188,7 @@ func TestGetVar(t *testing.T) { // for mymysql driver, interface{} will be []byte, so ignore it currently if testEngine.DriverName() != "mymysql" { var valuesInter = make(map[string]interface{}) - has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) + has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, 5, len(valuesInter)) @@ -243,7 +244,7 @@ func TestGetStruct(t *testing.T) { if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") + _, err = session.Exec("SET IDENTITY_INSERT `userinfo_get` ON") assert.NoError(t, err) } cnt, err := session.Insert(&UserinfoGet{Uid: 2}) @@ -300,6 +301,11 @@ func TestGetSlice(t *testing.T) { func TestGetMap(t *testing.T) { assert.NoError(t, PrepareEngine()) + if testEngine.Dialect().Features().AutoincrMode == dialects.SequenceAutoincrMode { + t.SkipNow() + return + } + type UserinfoMap struct { Uid int `xorm:"pk autoincr"` IsMan bool @@ -308,7 +314,7 @@ func TestGetMap(t *testing.T) { assertSync(t, new(UserinfoMap)) tableName := testEngine.Quote(testEngine.TableName("userinfo_map", true)) - _, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (is_man) VALUES (NULL)", tableName)) + _, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (`is_man`) VALUES (NULL)", tableName)) assert.NoError(t, err) var valuesString = make(map[string]string) @@ -479,7 +485,7 @@ func TestGetStructId(t *testing.T) { //var id int64 var maxid maxidst - sql := "select max(id) as id from " + testEngine.TableName(&TestGetStruct{}, true) + sql := "select max(`id`) as id from " + testEngine.Quote(testEngine.TableName(&TestGetStruct{}, true)) has, err := testEngine.SQL(sql).Get(&maxid) assert.NoError(t, err) assert.True(t, has) @@ -597,73 +603,78 @@ func TestGetNullVar(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestGetNullVarStruct)) - affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)") + if testEngine.Dialect().Features().AutoincrMode == dialects.SequenceAutoincrMode { + t.SkipNow() + return + } + + affected, err := testEngine.Exec("insert into " + testEngine.Quote(testEngine.TableName(new(TestGetNullVarStruct), true)) + " (`name`,`age`) values (null,null)") assert.NoError(t, err) a, _ := affected.RowsAffected() assert.EqualValues(t, 1, a) var name string - has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("name").Get(&name) + has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("name").Get(&name) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "", name) var age int - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age) var age2 int8 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age2) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age2) var age3 int16 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age3) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age3) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age3) var age4 int32 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age4) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age4) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age4) var age5 int64 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age5) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age5) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age5) var age6 uint - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age6) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age6) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age6) var age7 uint8 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age7) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age7) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age7) var age8 int16 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age8) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age8) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age8) var age9 int32 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age9) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age9) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age9) var age10 int64 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age10) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age10) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age10) @@ -697,7 +708,7 @@ func TestCustomTypes(t *testing.T) { assert.EqualValues(t, "test", name) var age MyInt - has, err = testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Select("age").Get(&age) + has, err = testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Select("`age`").Get(&age) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 32, age) @@ -759,7 +770,7 @@ func TestGetBigFloat(t *testing.T) { assert.NoError(t, err) var m big.Float - has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m) + has, err := testEngine.Table("get_big_float").Cols("money").Where("`id`=?", gf.Id).Get(&m) assert.NoError(t, err) assert.True(t, has) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) @@ -785,7 +796,7 @@ func TestGetBigFloat(t *testing.T) { assert.NoError(t, err) var m2 big.Float - has, err := testEngine.Table("get_big_float2").Cols("money").Where("id=?", gf2.Id).Get(&m2) + has, err := testEngine.Table("get_big_float2").Cols("money").Where("`id`=?", gf2.Id).Get(&m2) assert.NoError(t, err) assert.True(t, has) assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) @@ -825,7 +836,7 @@ func TestGetDecimal(t *testing.T) { assert.NoError(t, err) var m decimal.Decimal - has, err := testEngine.Table("get_decimal").Cols("money").Where("id=?", gf.Id).Get(&m) + has, err := testEngine.Table("get_decimal").Cols("money").Where("`id`=?", gf.Id).Get(&m) assert.NoError(t, err) assert.True(t, has) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) @@ -850,7 +861,7 @@ func TestGetDecimal(t *testing.T) { assert.NoError(t, err) var m decimal.Decimal - has, err := testEngine.Table("get_decimal2").Cols("money").Where("id=?", gf.Id).Get(&m) + has, err := testEngine.Table("get_decimal2").Cols("money").Where("`id`=?", gf.Id).Get(&m) assert.NoError(t, err) assert.True(t, has) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index cd56a958..70ec13f3 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -11,6 +11,7 @@ import ( "time" "xorm.io/xorm" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" ) @@ -191,8 +192,8 @@ func TestInsertDefault(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, -1, di.Status) - assert.EqualValues(t, di2.Updated.Unix(), di.Updated.Unix()) - assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) + assert.EqualValues(t, di2.Updated.Unix(), di.Updated.Unix(), di.Updated) + assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix(), di.Created) } func TestInsertDefault2(t *testing.T) { @@ -624,6 +625,11 @@ func TestAnonymousStruct(t *testing.T) { } func TestInsertMap(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + type InsertMap struct { Id int64 Width uint32 @@ -727,7 +733,7 @@ func TestInsertWhere(t *testing.T) { } inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). - Where("repo_id=?", 1). + Where("`repo_id`=?", 1). Insert(&i) assert.NoError(t, err) assert.EqualValues(t, 1, inserted) @@ -740,7 +746,12 @@ func TestInsertWhere(t *testing.T) { i.Index = 1 assert.EqualValues(t, i, j) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). Insert(map[string]interface{}{ "repo_id": 1, @@ -761,7 +772,7 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, "trest2", j2.Name) assert.EqualValues(t, 2, j2.Index) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). SetExpr("repo_id", "1"). Insert(map[string]string{ @@ -777,7 +788,7 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, "trest3", j3.Name) assert.EqualValues(t, 3, j3.Index) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). Insert(map[string]interface{}{ "repo_id": 1, @@ -793,7 +804,7 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, "10';delete * from insert_where; --", j4.Name) assert.EqualValues(t, 4, j4.Index) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). Insert(map[string]interface{}{ "repo_id": 1, @@ -846,6 +857,11 @@ func TestInsertExpr2(t *testing.T) { assert.EqualValues(t, 1, ie2.RepoId) assert.EqualValues(t, true, ie2.IsTag) + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + inserted, err = testEngine.Table(new(InsertExprsRelease)). SetExpr("is_draft", true). SetExpr("num_commits", 0). @@ -1067,6 +1083,11 @@ func TestInsertDeleted(t *testing.T) { } func TestInsertMultipleMap(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + type InsertMultipleMap struct { Id int64 Width uint32 diff --git a/integrations/session_iterate_test.go b/integrations/session_iterate_test.go index 564f457b..95f5a282 100644 --- a/integrations/session_iterate_test.go +++ b/integrations/session_iterate_test.go @@ -91,7 +91,7 @@ func TestBufferIterate(t *testing.T) { assert.EqualValues(t, 7, cnt) cnt = 0 - err = testEngine.Where("id <= 10").BufferSize(2).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { + err = testEngine.Where("`id` <= 10").BufferSize(2).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { user := bean.(*UserBufferIterate) assert.EqualValues(t, cnt+1, user.Id) assert.EqualValues(t, true, user.IsMan) diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index ef2ccdd6..edc77aec 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -5,7 +5,6 @@ package integrations import ( - "fmt" "strconv" "testing" "time" @@ -37,7 +36,7 @@ func TestQueryString(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true)) + records, err := testEngine.QueryString("select * from " + testEngine.Quote(testEngine.TableName("get_var2", true))) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -63,7 +62,7 @@ func TestQueryString2(t *testing.T) { _, err := testEngine.Insert(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true)) + records, err := testEngine.QueryString("select * from " + testEngine.Quote(testEngine.TableName("get_var3", true))) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 2, len(records[0])) @@ -71,42 +70,6 @@ func TestQueryString2(t *testing.T) { assert.True(t, "0" == records[0]["msg"] || "false" == records[0]["msg"]) } -func toString(i interface{}) string { - switch i.(type) { - case []byte: - return string(i.([]byte)) - case string: - return i.(string) - } - return fmt.Sprintf("%v", i) -} - -func toInt64(i interface{}) int64 { - switch i.(type) { - case []byte: - n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64) - return n - case int: - return int64(i.(int)) - case int64: - return i.(int64) - } - return 0 -} - -func toFloat64(i interface{}) float64 { - switch i.(type) { - case []byte: - n, _ := strconv.ParseFloat(string(i.([]byte)), 64) - return n - case float64: - return i.(float64) - case float32: - return float64(i.(float32)) - } - return 0 -} - func toBool(i interface{}) bool { switch t := i.(type) { case int32: @@ -138,7 +101,7 @@ func TestQueryInterface(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true)) + records, err := testEngine.QueryInterface("select * from " + testEngine.Quote(testEngine.TableName("get_var_interface", true))) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -192,7 +155,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, err) assertResult(t, results) - results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query() + results, err = testEngine.SQL("select * from " + testEngine.Quote(testEngine.TableName("query_no_params", true))).Query() assert.NoError(t, err) assertResult(t, results) } @@ -223,7 +186,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.EqualValues(t, "0", records[0]["msg"]) } - records, err = testEngine.Table("get_var4").Where(builder.Eq{"id": 1}).QueryString() + records, err = testEngine.Table("get_var4").Where(builder.Eq{"`id`": 1}).QueryString() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) @@ -260,7 +223,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.EqualValues(t, "0", records[0][1]) } - records, err = testEngine.Table("get_var6").Where(builder.Eq{"id": 1}).QuerySliceString() + records, err = testEngine.Table("get_var6").Where(builder.Eq{"`id`": 1}).QuerySliceString() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) @@ -293,7 +256,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { assert.EqualValues(t, 1, records[0]["id"]) assert.False(t, toBool(records[0]["msg"])) - records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface() + records, err = testEngine.Table("get_var5").Where(builder.Eq{"`id`": 1}).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, records[0]["id"]) @@ -340,7 +303,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.EqualValues(t, 3000, money) } - results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true))) + results, err := testEngine.Query(builder.Select("*").From(testEngine.Quote(testEngine.TableName("query_with_builder", true)))) assert.NoError(t, err) assertResult(t, results) } @@ -383,14 +346,14 @@ func TestJoinWithSubQuery(t *testing.T) { tbName := testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true)) var querys []JoinWithSubQuery1 - err = testEngine.Join("INNER", builder.Select("id").From(tbName), - "join_with_sub_query_depart.id = join_with_sub_query1.depart_id").Find(&querys) + err = testEngine.Join("INNER", builder.Select("`id`").From(tbName), + "`join_with_sub_query_depart`.`id` = `join_with_sub_query1`.`depart_id`").Find(&querys) assert.NoError(t, err) assert.EqualValues(t, 1, len(querys)) assert.EqualValues(t, q, querys[0]) querys = make([]JoinWithSubQuery1, 0, 1) - err = testEngine.Join("INNER", "(SELECT id FROM "+tbName+") join_with_sub_query_depart", "join_with_sub_query_depart.id = join_with_sub_query1.depart_id"). + err = testEngine.Join("INNER", "(SELECT `id` FROM "+tbName+") `a`", "`a`.`id` = `join_with_sub_query1`.`depart_id`"). Find(&querys) assert.NoError(t, err) assert.EqualValues(t, 1, len(querys)) diff --git a/integrations/session_raw_test.go b/integrations/session_raw_test.go index 36677683..44718f46 100644 --- a/integrations/session_raw_test.go +++ b/integrations/session_raw_test.go @@ -22,13 +22,13 @@ func TestExecAndQuery(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoQuery))) - res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user") + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (`uid`, `name`) VALUES (?, ?)", 1, "user") assert.NoError(t, err) cnt, err := res.RowsAffected() assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true)) + results, err := testEngine.Query("select * from " + testEngine.Quote(testEngine.TableName("userinfo_query", true))) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) id, err := strconv.Atoi(string(results[0]["uid"])) @@ -48,19 +48,19 @@ func TestExecTime(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoExecTime))) now := time.Now() - res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec_time`", true)+" (uid, name, created) VALUES (?, ?, ?)", 1, "user", now) + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec_time`", true)+" (`uid`, `name`, `created`) VALUES (?, ?, ?)", 1, "user", now) assert.NoError(t, err) cnt, err := res.RowsAffected() assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - results, err := testEngine.QueryString("SELECT * FROM " + testEngine.TableName("`userinfo_exec_time`", true)) + results, err := testEngine.QueryString("SELECT * FROM " + testEngine.Quote(testEngine.TableName("userinfo_exec_time", true))) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), results[0]["created"]) var uet UserinfoExecTime - has, err := testEngine.Where("uid=?", 1).Get(&uet) + has, err := testEngine.Where("`uid`=?", 1).Get(&uet) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05")) diff --git a/integrations/session_schema_test.go b/integrations/session_schema_test.go index 9cbebcbf..98ad9657 100644 --- a/integrations/session_schema_test.go +++ b/integrations/session_schema_test.go @@ -526,8 +526,9 @@ func TestModifyColum(t *testing.T) { SQLType: schemas.SQLType{ Name: "VARCHAR", }, - Length: 16, - Nullable: false, + Length: 16, + Nullable: false, + DefaultIsEmpty: true, }) _, err := testEngine.Exec(alterSQL) assert.NoError(t, err) diff --git a/integrations/session_test.go b/integrations/session_test.go index c3ef0513..eacf2ff5 100644 --- a/integrations/session_test.go +++ b/integrations/session_test.go @@ -18,7 +18,7 @@ func TestClose(t *testing.T) { sess1.Close() assert.True(t, sess1.IsClosed()) - sess2 := testEngine.Where("a = ?", 1) + sess2 := testEngine.Where("`a` = ?", 1) sess2.Close() assert.True(t, sess2.IsClosed()) } diff --git a/integrations/session_tx_test.go b/integrations/session_tx_test.go index 4cff5610..8d6519d0 100644 --- a/integrations/session_tx_test.go +++ b/integrations/session_tx_test.go @@ -37,7 +37,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Delete(&user2) @@ -70,10 +70,10 @@ func TestCombineTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) - _, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username) + _, err = session.Exec("delete from "+testEngine.Quote(testEngine.TableName("userinfo", true))+" where `username` = ?", user2.Username) assert.NoError(t, err) err = session.Commit() @@ -113,10 +113,10 @@ func TestCombineTransactionSameMapper(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) - _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) + _, err = session.Exec("delete from "+testEngine.Quote(testEngine.TableName("Userinfo", true))+" where `Username` = ?", user2.Username) assert.NoError(t, err) err = session.Commit() @@ -144,7 +144,7 @@ func TestMultipleTransaction(t *testing.T) { assert.NoError(t, err) user2 := MultipleTransaction{Name: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) err = session.Commit() @@ -158,7 +158,7 @@ func TestMultipleTransaction(t *testing.T) { err = session.Begin() assert.NoError(t, err) - _, err = session.Where("id=?", m1.Id).Delete(new(MultipleTransaction)) + _, err = session.Where("`id`=?", m1.Id).Delete(new(MultipleTransaction)) assert.NoError(t, err) err = session.Commit() diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index bbcc7600..4312d0e0 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -35,7 +35,7 @@ func TestUpdateMap(t *testing.T) { _, err := testEngine.Insert(&tb) assert.NoError(t, err) - cnt, err := testEngine.Table("update_table").Where("id = ?", tb.Id).Update(map[string]interface{}{ + cnt, err := testEngine.Table("update_table").Where("`id` = ?", tb.Id).Update(map[string]interface{}{ "name": "test2", "age": 36, }) @@ -93,7 +93,12 @@ func TestUpdateLimit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.OrderBy("name desc").Limit(1).Update(&UpdateTable2{ + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + + cnt, err = testEngine.OrderBy("`name` desc").Limit(1).Update(&UpdateTable2{ Age: 30, }) assert.NoError(t, err) @@ -166,7 +171,7 @@ func TestForUpdate(t *testing.T) { // use lock fList := make([]ForUpdate, 0) session1.ForUpdate() - session1.Where("id = ?", 1) + session1.Where("`id` = ?", 1) err = session1.Find(&fList) switch { case err != nil: @@ -187,7 +192,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f2 := new(ForUpdate) - session2.Where("id = ?", 1).ForUpdate() + session2.Where("`id` = ?", 1).ForUpdate() has, err := session2.Get(f2) // wait release lock switch { case err != nil: @@ -207,7 +212,7 @@ func TestForUpdate(t *testing.T) { wg2.Add(1) go func() { f3 := new(ForUpdate) - session3.Where("id = ?", 1) + session3.Where("`id` = ?", 1) has, err := session3.Get(f3) // wait release lock switch { case err != nil: @@ -225,7 +230,7 @@ func TestForUpdate(t *testing.T) { f := new(ForUpdate) f.Name = "updated by session1" - session1.Where("id = ?", 1) + session1.Where("`id` = ?", 1) session1.Update(f) // release lock @@ -300,7 +305,7 @@ func TestUpdateMap2(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(UpdateMustCols)) - _, err := testEngine.Table("update_must_cols").Where("id =?", 1).Update(map[string]interface{}{ + _, err := testEngine.Table("update_must_cols").Where("`id` =?", 1).Update(map[string]interface{}{ "bool": true, }) assert.NoError(t, err) @@ -345,11 +350,11 @@ func TestUpdate1(t *testing.T) { userID := user.Uid has, err := testEngine.ID(userID). - And("username = ?", user.Username). - And("height = ?", user.Height). - And("departname = ?", ""). - And("detail_id = ?", 0). - And("is_man = ?", false). + And("`username` = ?", user.Username). + And("`height` = ?", user.Height). + And("`departname` = ?", ""). + And("`detail_id` = ?", 0). + And("`is_man` = ?", false). Get(&Userinfo{}) assert.NoError(t, err) assert.True(t, has, "cannot insert properly") @@ -362,12 +367,12 @@ func TestUpdate1(t *testing.T) { assert.EqualValues(t, 1, cnt, "update not returned 1") has, err = testEngine.ID(userID). - And("username = ?", updatedUser.Username). - And("height IS NULL"). - And("departname IS NULL"). - And("is_man IS NULL"). - And("created IS NULL"). - And("detail_id = ?", 0). + And("`username` = ?", updatedUser.Username). + And("`height` IS NULL"). + And("`departname` IS NULL"). + And("`is_man` IS NULL"). + And("`created` IS NULL"). + And("`detail_id` = ?", 0). Get(&Userinfo{}) assert.NoError(t, err) assert.True(t, has, "cannot update with null properly") @@ -825,7 +830,7 @@ func TestNewUpdate(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 0, af) - af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", "13126564922").Update(&changeUsr) + af, err = testEngine.Table(new(TbUserInfo)).Where("`phone`=?", "13126564922").Update(&changeUsr) assert.NoError(t, err) assert.EqualValues(t, 0, af) } @@ -1166,7 +1171,7 @@ func TestUpdateExprs(t *testing.T) { }) assert.NoError(t, err) - _, err = testEngine.SetExpr("num_issues", "num_issues+1").AllCols().Update(&UpdateExprs{ + _, err = testEngine.SetExpr("num_issues", "`num_issues`+1").AllCols().Update(&UpdateExprs{ NumIssues: 3, Name: "lunny xiao", }) @@ -1197,7 +1202,7 @@ func TestUpdateAlias(t *testing.T) { }) assert.NoError(t, err) - _, err = testEngine.Alias("ua").Where("ua.id = ?", 1).Update(&UpdateAlias{ + _, err = testEngine.Alias("ua").Where("ua.`id` = ?", 1).Update(&UpdateAlias{ NumIssues: 2, Name: "lunny xiao", }) @@ -1237,7 +1242,7 @@ func TestUpdateExprs2(t *testing.T) { assert.EqualValues(t, 1, inserted) updated, err := testEngine. - Where("repo_id = ? AND is_tag = ?", 1, false). + Where("`repo_id` = ? AND `is_tag` = ?", 1, false). SetExpr("is_draft", true). SetExpr("num_commits", 0). SetExpr("sha1", ""). @@ -1257,6 +1262,11 @@ func TestUpdateExprs2(t *testing.T) { } func TestUpdateMap3(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + assert.NoError(t, PrepareEngine()) type UpdateMapUser struct { @@ -1308,7 +1318,7 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { assertGetRecord := func() *TestOnlyFromDBField { var record TestOnlyFromDBField - has, err := testEngine.Where("id = ?", 1).Get(&record) + has, err := testEngine.Where("`id` = ?", 1).Get(&record) assert.NoError(t, err) assert.EqualValues(t, true, has) assert.EqualValues(t, "", record.OnlyFromDBField) diff --git a/integrations/tags_test.go b/integrations/tags_test.go index fc7b505e..b5bf222e 100644 --- a/integrations/tags_test.go +++ b/integrations/tags_test.go @@ -458,7 +458,7 @@ func TestExtends5(t *testing.T) { list := make([]Book, 0) err = session. Select(fmt.Sprintf( - "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", + "%s.%s, `sc`.%s AS %s, `sc`.%s AS %s, `s`.%s, `s`.%s", quote(bookTableName), quote("id"), quote("Width"), @@ -472,12 +472,12 @@ func TestExtends5(t *testing.T) { Join( "LEFT", sizeTableName+" AS `sc`", - bookTableName+".`SizeClosed`=sc.`id`", + bookTableName+".`SizeClosed`=`sc`.`id`", ). Join( "LEFT", sizeTableName+" AS `s`", - bookTableName+".`Size`=s.`id`", + bookTableName+".`Size`=`s`.`id`", ). Find(&list) assert.NoError(t, err) @@ -730,7 +730,7 @@ func TestLowerCase(t *testing.T) { err := testEngine.Sync2(&Lowercase{}) assert.NoError(t, err) - _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) + _, err = testEngine.Where("`id` > 0").Delete(&Lowercase{}) assert.NoError(t, err) _, err = testEngine.Insert(&Lowercase{ended: 1}) diff --git a/integrations/time_test.go b/integrations/time_test.go index cd2e879f..a8447eea 100644 --- a/integrations/time_test.go +++ b/integrations/time_test.go @@ -324,7 +324,7 @@ func TestTimeUserDeleted(t *testing.T) { fmt.Println("user2 str", user2.CreatedAtStr, user2.UpdatedAtStr) var user3 UserDeleted - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) @@ -386,7 +386,7 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted2 - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) @@ -457,7 +457,7 @@ func TestCustomTimeUserDeleted(t *testing.T) { fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted3 - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) @@ -519,7 +519,7 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted4 - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) diff --git a/integrations/types_null_test.go b/integrations/types_null_test.go index 86ce1939..8d98b456 100644 --- a/integrations/types_null_test.go +++ b/integrations/types_null_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" ) -type NullType struct { +type NullStruct struct { Id int `xorm:"pk autoincr"` Name sql.NullString Age sql.NullInt64 @@ -65,26 +65,26 @@ func (m CustomStruct) Value() (driver.Value, error) { func TestCreateNullStructTable(t *testing.T) { assert.NoError(t, PrepareEngine()) - err := testEngine.CreateTables(new(NullType)) + err := testEngine.CreateTables(new(NullStruct)) assert.NoError(t, err) } func TestDropNullStructTable(t *testing.T) { assert.NoError(t, PrepareEngine()) - err := testEngine.DropTables(new(NullType)) + err := testEngine.DropTables(new(NullStruct)) assert.NoError(t, err) } func TestNullStructInsert(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - item1 := new(NullType) + item1 := new(NullStruct) _, err := testEngine.Insert(item1) assert.NoError(t, err) assert.EqualValues(t, 1, item1.Id) - item := NullType{ + item := NullStruct{ Name: sql.NullString{String: "haolei", Valid: true}, Age: sql.NullInt64{Int64: 34, Valid: true}, Height: sql.NullFloat64{Float64: 1.72, Valid: true}, @@ -95,9 +95,9 @@ func TestNullStructInsert(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, item.Id) - items := []NullType{} + items := []NullStruct{} for i := 0; i < 5; i++ { - item := NullType{ + item := NullStruct{ Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true}, Age: sql.NullInt64{Int64: 30 + int64(i), Valid: true}, Height: sql.NullFloat64{Float64: 1.5 + 1.1*float64(i), Valid: true}, @@ -111,7 +111,7 @@ func TestNullStructInsert(t *testing.T) { _, err = testEngine.Insert(&items) assert.NoError(t, err) - items = make([]NullType, 0, 7) + items = make([]NullStruct, 0, 7) err = testEngine.Find(&items) assert.NoError(t, err) assert.EqualValues(t, 7, len(items)) @@ -119,9 +119,9 @@ func TestNullStructInsert(t *testing.T) { func TestNullStructUpdate(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - _, err := testEngine.Insert([]NullType{ + _, err := testEngine.Insert([]NullStruct{ { Name: sql.NullString{ String: "name1", @@ -150,7 +150,7 @@ func TestNullStructUpdate(t *testing.T) { assert.NoError(t, err) if true { // 测试可插入NULL - item := new(NullType) + item := new(NullStruct) item.Age = sql.NullInt64{Int64: 23, Valid: true} item.Height = sql.NullFloat64{Float64: 0, Valid: false} // update to NULL @@ -160,7 +160,7 @@ func TestNullStructUpdate(t *testing.T) { } if true { // 测试In update - item := new(NullType) + item := new(NullStruct) item.Age = sql.NullInt64{Int64: 23, Valid: true} affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item) assert.NoError(t, err) @@ -168,17 +168,17 @@ func TestNullStructUpdate(t *testing.T) { } if true { // 测试where - item := new(NullType) + item := new(NullStruct) item.Name = sql.NullString{String: "nullname", Valid: true} item.IsMan = sql.NullBool{Bool: true, Valid: true} item.Age = sql.NullInt64{Int64: 34, Valid: true} - _, err := testEngine.Where("age > ?", 34).Update(item) + _, err := testEngine.Where("`age` > ?", 34).Update(item) assert.NoError(t, err) } if true { // 修改全部时,插入空值 - item := &NullType{ + item := &NullStruct{ Name: sql.NullString{String: "winxxp", Valid: true}, Age: sql.NullInt64{Int64: 30, Valid: true}, Height: sql.NullFloat64{Float64: 1.72, Valid: true}, @@ -192,9 +192,9 @@ func TestNullStructUpdate(t *testing.T) { func TestNullStructFind(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - _, err := testEngine.Insert([]NullType{ + _, err := testEngine.Insert([]NullStruct{ { Name: sql.NullString{ String: "name1", @@ -223,7 +223,7 @@ func TestNullStructFind(t *testing.T) { assert.NoError(t, err) if true { - item := new(NullType) + item := new(NullStruct) has, err := testEngine.ID(1).Get(item) assert.NoError(t, err) assert.True(t, has) @@ -235,7 +235,7 @@ func TestNullStructFind(t *testing.T) { } if true { - item := new(NullType) + item := new(NullStruct) item.Id = 2 has, err := testEngine.Get(item) assert.NoError(t, err) @@ -243,13 +243,13 @@ func TestNullStructFind(t *testing.T) { } if true { - item := make([]NullType, 0) + item := make([]NullStruct, 0) err := testEngine.ID(2).Find(&item) assert.NoError(t, err) } if true { - item := make([]NullType, 0) + item := make([]NullStruct, 0) err := testEngine.Asc("age").Find(&item) assert.NoError(t, err) } @@ -257,12 +257,12 @@ func TestNullStructFind(t *testing.T) { func TestNullStructIterate(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) if true { - err := testEngine.Where("age IS NOT NULL").OrderBy("age").Iterate(new(NullType), + err := testEngine.Where("`age` IS NOT NULL").OrderBy("age").Iterate(new(NullStruct), func(i int, bean interface{}) error { - nultype := bean.(*NullType) + nultype := bean.(*NullStruct) fmt.Println(i, nultype) return nil }) @@ -272,21 +272,21 @@ func TestNullStructIterate(t *testing.T) { func TestNullStructCount(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) if true { - item := new(NullType) - _, err := testEngine.Where("age IS NOT NULL").Count(item) + item := new(NullStruct) + _, err := testEngine.Where("`age` IS NOT NULL").Count(item) assert.NoError(t, err) } } func TestNullStructRows(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - item := new(NullType) - rows, err := testEngine.Where("id > ?", 1).Rows(item) + item := new(NullStruct) + rows, err := testEngine.Where("`id` > ?", 1).Rows(item) assert.NoError(t, err) defer rows.Close() @@ -298,13 +298,13 @@ func TestNullStructRows(t *testing.T) { func TestNullStructDelete(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - item := new(NullType) + item := new(NullStruct) _, err := testEngine.ID(1).Delete(item) assert.NoError(t, err) - _, err = testEngine.Where("id > ?", 1).Delete(item) + _, err = testEngine.Where("`id` > ?", 1).Delete(item) assert.NoError(t, err) } diff --git a/integrations/types_test.go b/integrations/types_test.go index b08dfec5..48facb21 100644 --- a/integrations/types_test.go +++ b/integrations/types_test.go @@ -428,7 +428,7 @@ func TestUnsignedUint64(t *testing.T) { assert.EqualValues(t, "INTEGER", tables[0].Columns()[0].SQLType.Name) case schemas.MYSQL: assert.EqualValues(t, "UNSIGNED BIGINT", tables[0].Columns()[0].SQLType.Name) - case schemas.POSTGRES: + case schemas.POSTGRES, schemas.DAMENG: assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) case schemas.MSSQL: assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) @@ -472,9 +472,7 @@ func TestUnsignedUint32(t *testing.T) { assert.EqualValues(t, "INTEGER", tables[0].Columns()[0].SQLType.Name) case schemas.MYSQL: assert.EqualValues(t, "UNSIGNED INT", tables[0].Columns()[0].SQLType.Name) - case schemas.POSTGRES: - assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) - case schemas.MSSQL: + case schemas.POSTGRES, schemas.MSSQL, schemas.DAMENG: assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) default: assert.False(t, true, "Unsigned is not implemented") @@ -507,7 +505,7 @@ func TestUnsignedTinyInt(t *testing.T) { assert.EqualValues(t, 1, len(tables[0].Columns())) switch testEngine.Dialect().URI().DBType { - case schemas.SQLITE: + case schemas.SQLITE, schemas.DAMENG: assert.EqualValues(t, "INTEGER", tables[0].Columns()[0].SQLType.Name) case schemas.MYSQL: assert.EqualValues(t, "UNSIGNED TINYINT", tables[0].Columns()[0].SQLType.Name) @@ -516,7 +514,7 @@ func TestUnsignedTinyInt(t *testing.T) { case schemas.MSSQL: assert.EqualValues(t, "INT", tables[0].Columns()[0].SQLType.Name) default: - assert.False(t, true, "Unsigned is not implemented") + assert.False(t, true, fmt.Sprintf("Unsigned is not implemented, returned %s", tables[0].Columns()[0].SQLType.Name)) } cnt, err := testEngine.Insert(&MyUnsignedTinyIntStruct{ diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 84547cdf..91a33319 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -10,6 +10,7 @@ import ( "strings" "xorm.io/builder" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -42,7 +43,19 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(colNames) <= 0 { + var hasInsertColumns = len(colNames) > 0 + var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + if needSeq { + for _, col := range colNames { + if strings.EqualFold(col, table.AutoIncrement) { + needSeq = false + break + } + } + } + + if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && + statement.dialect.URI().DBType != schemas.DAMENG { if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { return "", nil, err @@ -60,6 +73,10 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + if needSeq { + colNames = append(colNames, table.AutoIncrement) + } + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -80,13 +97,23 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + if needSeq { + if len(args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + return "", nil, err + } + } if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } - } - if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } } if _, err := buf.WriteString(" FROM "); err != nil { @@ -113,6 +140,18 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + // Insert tablename (id) Values(seq_tablename.nextval) + if needSeq { + if hasInsertColumns { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + return "", nil, err + } + } + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 1fcc0bba..b85773db 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -539,7 +539,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName = schemas.CommonQuoter.Trim(aliasName) - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) case *builder.Builder: subSQL, subQueryArgs, err := tp.ToSQL() @@ -552,7 +552,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName = schemas.CommonQuoter.Trim(aliasName) - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) @@ -560,6 +560,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition var buf strings.Builder statement.dialect.Quoter().QuoteTo(&buf, tbName) tbName = buf.String() + } else { + tbName = statement.ReplaceQuote(tbName) } fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) } @@ -642,14 +644,6 @@ func (statement *Statement) genColumnStr() string { return buf.String() } -// GenCreateTableSQL generated create table SQL -func (statement *Statement) GenCreateTableSQL() []string { - statement.RefTable.StoreEngine = statement.StoreEngine - statement.RefTable.Charset = statement.Charset - s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) - return s -} - // GenIndexSQL generated create index SQL func (statement *Statement) GenIndexSQL() []string { var sqls []string diff --git a/internal/utils/name.go b/internal/utils/name.go index 840dd9e8..aeef683d 100644 --- a/internal/utils/name.go +++ b/internal/utils/name.go @@ -6,9 +6,15 @@ package utils import ( "fmt" + "strings" ) // IndexName returns index name func IndexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } + +// SeqName returns sequence name for some table +func SeqName(tableName string) string { + return "SEQ_" + strings.ToUpper(tableName) +} diff --git a/internal/utils/slice.go b/internal/utils/slice.go index 89685706..06a1a006 100644 --- a/internal/utils/slice.go +++ b/internal/utils/slice.go @@ -11,8 +11,8 @@ func SliceEq(left, right []string) bool { if len(left) != len(right) { return false } - sort.Sort(sort.StringSlice(left)) - sort.Sort(sort.StringSlice(right)) + sort.Strings(left) + sort.Strings(right) for i := 0; i < len(left); i++ { if left[i] != right[i] { return false @@ -20,3 +20,12 @@ func SliceEq(left, right []string) bool { } return true } + +func IndexSlice(s []string, c string) int { + for i, ss := range s { + if c == ss { + return i + } + } + return -1 +} diff --git a/scan.go b/scan.go index 56d3c9d6..2788453e 100644 --- a/scan.go +++ b/scan.go @@ -136,7 +136,10 @@ func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, field scanResults[i] = &s } - if err := rows.Scan(scanResults...); err != nil { + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { return nil, err } diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 8e351dc0..061a6ea2 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -176,6 +176,10 @@ func TestReplace(t *testing.T) { "UPDATE table SET `a` = ~ `a`, `b`='abc`'", "UPDATE table SET [a] = ~ [a], [b]='abc`'", }, + { + "INSERT INTO `insert_where` (`height`,`name`,`repo_id`,`width`,`index`) SELECT $1,$2,$3,$4,coalesce(MAX(`index`),0)+1 FROM `insert_where` WHERE (`repo_id`=$5)", + "INSERT INTO [insert_where] ([height],[name],[repo_id],[width],[index]) SELECT $1,$2,$3,$4,coalesce(MAX([index]),0)+1 FROM [insert_where] WHERE ([repo_id]=$5)", + }, } for _, kase := range kases { diff --git a/schemas/type.go b/schemas/type.go index cf730134..d192bac6 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -22,6 +22,7 @@ const ( MYSQL DBType = "mysql" MSSQL DBType = "mssql" ORACLE DBType = "oracle" + DAMENG DBType = "dameng" ) // SQLType represents SQL types @@ -105,12 +106,14 @@ var ( Integer = "INTEGER" BigInt = "BIGINT" UnsignedBigInt = "UNSIGNED BIGINT" + Number = "NUMBER" Enum = "ENUM" Set = "SET" Char = "CHAR" Varchar = "VARCHAR" + VARCHAR2 = "VARCHAR2" NChar = "NCHAR" NVarchar = "NVARCHAR" TinyText = "TINYTEXT" @@ -174,6 +177,7 @@ var ( Integer: NUMERIC_TYPE, BigInt: NUMERIC_TYPE, UnsignedBigInt: NUMERIC_TYPE, + Number: NUMERIC_TYPE, Enum: TEXT_TYPE, Set: TEXT_TYPE, @@ -185,6 +189,7 @@ var ( Char: TEXT_TYPE, NChar: TEXT_TYPE, Varchar: TEXT_TYPE, + VARCHAR2: TEXT_TYPE, NVarchar: TEXT_TYPE, TinyText: TEXT_TYPE, Text: TEXT_TYPE, diff --git a/session.go b/session.go index 499b7df4..a96d2fc9 100644 --- a/session.go +++ b/session.go @@ -524,6 +524,9 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if !ok { return fmt.Errorf("cannot convert %#v as bytes", scanResult) } + if data == nil { + return nil + } return structConvert.FromDB(data) } } diff --git a/session_get.go b/session_get.go index 08172524..22b116a9 100644 --- a/session_get.go +++ b/session_get.go @@ -130,9 +130,6 @@ var ( valuerTypePlaceHolder driver.Valuer valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() - scannerTypePlaceHolder sql.Scanner - scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() - conversionTypePlaceHolder convert.Conversion conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() ) diff --git a/session_insert.go b/session_insert.go index a8f365c7..43a4118b 100644 --- a/session_insert.go +++ b/session_insert.go @@ -123,6 +123,12 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } fieldValue := *ptrFieldValue if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { + if session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { + if i == 0 { + colNames = append(colNames, col.Name) + } + colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval") + } continue } if col.MapType == schemas.ONLYFROMDB { @@ -277,6 +283,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { if err != nil { return 0, err } + sqlStr = session.engine.dialect.Quoter().Replace(sqlStr) handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { @@ -307,16 +314,49 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { // if there is auto increment column and driver don't support return it if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { - var sql = sqlStr - if session.engine.dialect.URI().DBType == schemas.ORACLE { - sql = "select seq_atable.currval from dual" + var sql string + var newArgs []interface{} + var needCommit bool + var id int64 + if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG { + if session.isAutoCommit { // if it's not in transaction + if err := session.Begin(); err != nil { + return 0, err + } + needCommit = true + } + _, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + i := utils.IndexSlice(colNames, table.AutoIncrement) + if i > -1 { + id, err = convert.AsInt64(args[i]) + if err != nil { + return 0, err + } + } else { + sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) + } + } else { + sql = sqlStr + newArgs = args } - rows, err := session.queryRows(sql, args...) - if err != nil { - return 0, err + if id == 0 { + err := session.queryRow(sql, newArgs...).Scan(&id) + if err != nil { + return 0, err + } + if needCommit { + if err := session.Commit(); err != nil { + return 0, err + } + } + if id == 0 { + return 0, errors.New("insert successfully but not returned id") + } } - defer rows.Close() defer handleAfterInsertProcessorFunc(bean) @@ -331,16 +371,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { } } - var id int64 - if !rows.Next() { - if rows.Err() != nil { - return 0, rows.Err() - } - return 0, errors.New("insert successfully but not returned id") - } - if err := rows.Scan(&id); err != nil { - return 1, err - } aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { session.engine.logger.Errorf("%v", err) @@ -628,6 +658,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, if err != nil { return 0, err } + sql = session.engine.dialect.Quoter().Replace(sql) if err := session.cacheInsert(tableName); err != nil { return 0, err @@ -654,6 +685,7 @@ func (session *Session) insertMultipleMap(columns []string, argss [][]interface{ if err != nil { return 0, err } + sql = session.engine.dialect.Quoter().Replace(sql) if err := session.cacheInsert(tableName); err != nil { return 0, err diff --git a/session_schema.go b/session_schema.go index 2e64350f..73352135 100644 --- a/session_schema.go +++ b/session_schema.go @@ -6,12 +6,14 @@ package xorm import ( "bufio" + "context" "database/sql" "fmt" "io" "os" "strings" + "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -40,13 +42,28 @@ func (session *Session) createTable(bean interface{}) error { return err } - sqlStrs := session.statement.GenCreateTableSQL() - for _, s := range sqlStrs { - _, err := session.exec(s) + session.statement.RefTable.StoreEngine = session.statement.StoreEngine + session.statement.RefTable.Charset = session.statement.Charset + tableName := session.statement.TableName() + refTable := session.statement.RefTable + if refTable.AutoIncrement != "" && session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { + sqlStr, err := session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName)) if err != nil { return err } + if _, err := session.exec(sqlStr); err != nil { + return err + } } + + sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName) + if err != nil { + return err + } + if _, err := session.exec(sqlStr); err != nil { + return err + } + return nil } @@ -141,11 +158,32 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { checkIfExist = exist } - if checkIfExist { - _, err := session.exec(sqlStr) + if !checkIfExist { + return nil + } + if _, err := session.exec(sqlStr); err != nil { return err } - return nil + + if session.engine.dialect.Features().AutoincrMode == dialects.IncrAutoincrMode { + return nil + } + + var seqName = utils.SeqName(tableName) + exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) + if err != nil { + return err + } + if !exist { + return nil + } + + sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName) + if err != nil { + return err + } + _, err = session.exec(sqlStr) + return err } // IsTableExist if a table is exist diff --git a/session_update.go b/session_update.go index 7d91346e..4fd45a53 100644 --- a/session_update.go +++ b/session_update.go @@ -278,7 +278,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { - autoCond = builder.Eq(c) + var eq = make(builder.Eq) + for k, v := range c { + eq[session.engine.Quote(k)] = v + } + autoCond = builder.Eq(eq) } else { ct := reflect.TypeOf(condiBean[0]) k := ct.Kind()