diff --git a/core/db.go b/core/db.go index 50c64c6f..ef5ab227 100644 --- a/core/db.go +++ b/core/db.go @@ -23,6 +23,7 @@ var ( DefaultCacheSize = 200 ) +// MapToSlice map query and struct as sql and args func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -44,6 +45,7 @@ func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { return query, args, err } +// StructToSlice converts a query and struct as sql and args func StructToSlice(query string, st interface{}) (string, []interface{}, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -176,6 +178,7 @@ func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { return db.QueryMapContext(context.Background(), query, mp) } +// QueryStructContext query rows with struct func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -184,10 +187,12 @@ func (db *DB) QueryStructContext(ctx context.Context, query string, st interface return db.QueryContext(ctx, query, args...) } +// QueryStruct query rows with struct func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { return db.QueryStructContext(context.Background(), query, st) } +// QueryRowContext query row with args func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := db.QueryContext(ctx, query, args...) if err != nil { @@ -196,10 +201,12 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interfa return &Row{rows, nil} } +// QueryRow query row with args func (db *DB) QueryRow(query string, args ...interface{}) *Row { return db.QueryRowContext(context.Background(), query, args...) } +// QueryRowMapContext query row with map func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { @@ -208,10 +215,12 @@ func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface return db.QueryRowContext(ctx, query, args...) } +// QueryRowMap query row with map func (db *DB) QueryRowMap(query string, mp interface{}) *Row { return db.QueryRowMapContext(context.Background(), query, mp) } +// QueryRowStructContext query row with struct func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { @@ -220,6 +229,7 @@ func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interf return db.QueryRowContext(ctx, query, args...) } +// QueryRowStruct query row with struct func (db *DB) QueryRowStruct(query string, st interface{}) *Row { return db.QueryRowStructContext(context.Background(), query, st) } @@ -239,10 +249,12 @@ func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) return db.ExecContext(ctx, query, args...) } +// ExecMap exec query with map func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { return db.ExecMapContext(context.Background(), query, mp) } +// ExecStructContext exec query with map func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -251,6 +263,7 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ return db.ExecContext(ctx, query, args...) } +// ExecContext exec query with args func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := db.beforeProcess(hookCtx) @@ -265,6 +278,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{} return res, nil } +// ExecStruct exec query with struct func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } @@ -288,6 +302,7 @@ func (db *DB) afterProcess(c *contexts.ContextHook) error { return err } +// AddHook adds hook func (db *DB) AddHook(h ...contexts.Hook) { db.hooks.AddHook(h...) } diff --git a/core/db_test.go b/core/db_test.go index 104c5b95..e9c2d82d 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -21,7 +21,7 @@ import ( var ( dbtype = flag.String("dbtype", "sqlite3", "database type") dbConn = flag.String("dbConn", "./db_test.db", "database connect string") - createTableSql string + createTableSQL string ) func TestMain(m *testing.M) { @@ -29,12 +29,12 @@ func TestMain(m *testing.M) { switch *dbtype { case "sqlite3", "sqlite": - createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" case "mysql": fallthrough default: - createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" } @@ -66,7 +66,7 @@ func BenchmarkOriQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -121,7 +121,7 @@ func BenchmarkStructQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -166,7 +166,7 @@ func BenchmarkStruct2Query(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -212,7 +212,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -270,7 +270,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -321,7 +321,7 @@ func BenchmarkSliceStringQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -372,7 +372,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -426,7 +426,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -473,7 +473,7 @@ func BenchmarkMapStringQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -519,7 +519,7 @@ func BenchmarkExec(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -544,7 +544,7 @@ func BenchmarkExecMap(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -577,7 +577,7 @@ func TestExecMap(t *testing.T) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { t.Error(err) } @@ -620,7 +620,7 @@ func TestExecStruct(t *testing.T) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { t.Error(err) } @@ -663,7 +663,7 @@ func BenchmarkExecStruct(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } diff --git a/core/rows.go b/core/rows.go index a1e8bfbc..c15a59a3 100644 --- a/core/rows.go +++ b/core/rows.go @@ -11,11 +11,13 @@ import ( "sync" ) +// Rows represents rows of table type Rows struct { *sql.Rows db *DB } +// ToMapString returns all records func (rs *Rows) ToMapString() ([]map[string]string, error) { cols, err := rs.Columns() if err != nil { @@ -34,7 +36,7 @@ func (rs *Rows) ToMapString() ([]map[string]string, error) { return results, nil } -// scan data to a struct's pointer according field index +// ScanStructByIndex scan data to a struct's pointer according field index func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { if len(dest) == 0 { return errors.New("at least one struct") @@ -94,7 +96,7 @@ func fieldByName(v reflect.Value, name string) reflect.Value { return reflect.Zero(t) } -// scan data to a struct's pointer according field name +// ScanStructByName scan data to a struct's pointer according field name func (rs *Rows) ScanStructByName(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -120,7 +122,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { return rs.Rows.Scan(newDest...) } -// scan data to a slice's pointer, slice's length should equal to columns' number +// ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number func (rs *Rows) ScanSlice(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { @@ -155,7 +157,7 @@ func (rs *Rows) ScanSlice(dest interface{}) error { return nil } -// scan data to a map's pointer +// ScanMap scan data to a map's pointer func (rs *Rows) ScanMap(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -187,6 +189,7 @@ func (rs *Rows) ScanMap(dest interface{}) error { return nil } +// Row reprents a row of a tab type Row struct { rows *Rows // One of these two will be non-nil: @@ -205,6 +208,7 @@ func NewRow(rows *Rows, err error) *Row { return &Row{rows, err} } +// Columns returns all columns of the row func (row *Row) Columns() ([]string, error) { if row.err != nil { return nil, row.err @@ -212,6 +216,7 @@ func (row *Row) Columns() ([]string, error) { return row.rows.Columns() } +// Scan retrieves all row column values func (row *Row) Scan(dest ...interface{}) error { if row.err != nil { return row.err @@ -238,6 +243,7 @@ func (row *Row) Scan(dest ...interface{}) error { return row.rows.Close() } +// ScanStructByName retrieves all row column values into a struct func (row *Row) ScanStructByName(dest interface{}) error { if row.err != nil { return row.err @@ -258,6 +264,7 @@ func (row *Row) ScanStructByName(dest interface{}) error { return row.rows.Close() } +// ScanStructByIndex retrieves all row column values into a struct func (row *Row) ScanStructByIndex(dest interface{}) error { if row.err != nil { return row.err @@ -278,7 +285,7 @@ func (row *Row) ScanStructByIndex(dest interface{}) error { return row.rows.Close() } -// scan data to a slice's pointer, slice's length should equal to columns' number +// ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number func (row *Row) ScanSlice(dest interface{}) error { if row.err != nil { return row.err @@ -300,7 +307,7 @@ func (row *Row) ScanSlice(dest interface{}) error { return row.rows.Close() } -// scan data to a map's pointer +// ScanMap scan data to a map's pointer func (row *Row) ScanMap(dest interface{}) error { if row.err != nil { return row.err @@ -322,6 +329,7 @@ func (row *Row) ScanMap(dest interface{}) error { return row.rows.Close() } +// ToMapString returns all clumns of this record func (row *Row) ToMapString() (map[string]string, error) { cols, err := row.Columns() if err != nil { diff --git a/core/scan.go b/core/scan.go index 897b5341..1e7e4525 100644 --- a/core/scan.go +++ b/core/scan.go @@ -10,12 +10,14 @@ import ( "time" ) +// NullTime defines a customize type NullTime type NullTime time.Time var ( _ driver.Valuer = NullTime{} ) +// Scan implements driver.Valuer func (ns *NullTime) Scan(value interface{}) error { if value == nil { return nil @@ -58,9 +60,11 @@ func convertTime(dest *NullTime, src interface{}) error { return nil } +// EmptyScanner represents an empty scanner type EmptyScanner struct { } +// Scan implements func (EmptyScanner) Scan(src interface{}) error { return nil } diff --git a/core/stmt.go b/core/stmt.go index d46ac9c6..260843d5 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -21,6 +21,7 @@ type Stmt struct { query string } +// PrepareContext creates a prepare statement func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { names := make(map[string]int) var i int @@ -42,10 +43,12 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return &Stmt{stmt, db, names, query}, nil } +// Prepare creates a prepare statement func (db *DB) Prepare(query string) (*Stmt, error) { return db.PrepareContext(context.Background(), query) } +// ExecMapContext execute with map func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -59,10 +62,12 @@ func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, return s.ExecContext(ctx, args...) } +// ExecMap executes with map func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { return s.ExecMapContext(context.Background(), mp) } +// ExecStructContext executes with struct func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -76,10 +81,12 @@ func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Resul return s.ExecContext(ctx, args...) } +// ExecStruct executes with struct func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { return s.ExecStructContext(context.Background(), st) } +// ExecContext with args func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, s.query, args) ctx, err := s.db.beforeProcess(hookCtx) @@ -94,6 +101,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result return res, nil } +// QueryContext query with args func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { hookCtx := contexts.NewContextHook(ctx, s.query, args) ctx, err := s.db.beforeProcess(hookCtx) @@ -108,10 +116,12 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return &Rows{rows, s.db}, nil } +// Query query with args func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } +// QueryMapContext query with map func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -126,10 +136,12 @@ func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, erro return s.QueryContext(ctx, args...) } +// QueryMap query with map func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { return s.QueryMapContext(context.Background(), mp) } +// QueryStructContext query with struct func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -144,19 +156,23 @@ func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, e return s.QueryContext(ctx, args...) } +// QueryStruct query with struct func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { return s.QueryStructContext(context.Background(), st) } +// QueryRowContext query row with args func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { rows, err := s.QueryContext(ctx, args...) return &Row{rows, err} } +// QueryRow query row with args func (s *Stmt) QueryRow(args ...interface{}) *Row { return s.QueryRowContext(context.Background(), args...) } +// QueryRowMapContext query row with map func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -171,10 +187,12 @@ func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { return s.QueryRowContext(ctx, args...) } +// QueryRowMap query row with map func (s *Stmt) QueryRowMap(mp interface{}) *Row { return s.QueryRowMapContext(context.Background(), mp) } +// QueryRowStructContext query row with struct func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -189,6 +207,7 @@ func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { return s.QueryRowContext(ctx, args...) } +// QueryRowStruct query row with struct func (s *Stmt) QueryRowStruct(st interface{}) *Row { return s.QueryRowStructContext(context.Background(), st) } diff --git a/core/tx.go b/core/tx.go index 24d548b3..a2f745f8 100644 --- a/core/tx.go +++ b/core/tx.go @@ -51,10 +51,7 @@ func (tx *Tx) Commit() error { } err = tx.Tx.Commit() hookCtx.End(ctx, nil, err) - if err := tx.db.afterProcess(hookCtx); err != nil { - return err - } - return nil + return tx.db.afterProcess(hookCtx) } // Rollback rollback the transaction @@ -66,10 +63,7 @@ func (tx *Tx) Rollback() error { } err = tx.Tx.Rollback() hookCtx.End(ctx, nil, err) - if err := tx.db.afterProcess(hookCtx); err != nil { - return err - } - return nil + return tx.db.afterProcess(hookCtx) } // PrepareContext prepare the query diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index f3565963..cc7e861d 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -97,6 +97,7 @@ func TestDeleted(t *testing.T) { // Test normal Find() var records1 []Deleted err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &Deleted{}) + assert.NoError(t, err) assert.EqualValues(t, 3, len(records1)) // Test normal Get() @@ -132,6 +133,7 @@ func TestDeleted(t *testing.T) { record2 := &Deleted{} has, err = testEngine.ID(2).Get(record2) assert.NoError(t, err) + assert.True(t, has) assert.True(t, record2.DeletedAt.IsZero()) // Test find all records whatever `deleted`. diff --git a/internal/statements/cache.go b/internal/statements/cache.go index cb33df08..669cd018 100644 --- a/internal/statements/cache.go +++ b/internal/statements/cache.go @@ -12,6 +12,7 @@ import ( "xorm.io/xorm/schemas" ) +// ConvertIDSQL converts SQL with id func (statement *Statement) ConvertIDSQL(sqlStr string) string { if statement.RefTable != nil { cols := statement.RefTable.PKColumns() @@ -37,6 +38,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string { return "" } +// ConvertUpdateSQL converts update SQL func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { return "", "" diff --git a/internal/statements/query.go b/internal/statements/query.go index ab3021bf..f1b36770 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -14,6 +14,7 @@ import ( "xorm.io/xorm/schemas" ) +// GenQuerySQL generate query SQL func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { if len(sqlOrArgs) > 0 { return statement.ConvertSQLOrArgs(sqlOrArgs...) @@ -72,6 +73,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return sqlStr, args, nil } +// GenSumSQL generates sum SQL func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil @@ -102,6 +104,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri return sqlStr, append(statement.joinArgs, condArgs...), nil } +// GenGetSQL generates Get SQL func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { v := rValue(bean) isStruct := v.Kind() == reflect.Struct @@ -316,6 +319,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB return buf.String(), condArgs, nil } +// GenExistSQL generates Exist SQL func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil @@ -385,6 +389,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return sqlStr, args, nil } +// GenFindSQL generates Find SQL func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 87f785ae..3dd036a6 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -90,14 +90,11 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ return statement } +// SetTableName set table name func (statement *Statement) SetTableName(tableName string) { statement.tableName = tableName } -func (statement *Statement) omitStr() string { - return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") -} - // GenRawSQL generates correct raw sql func (statement *Statement) GenRawSQL() string { return statement.ReplaceQuote(statement.RawSQL) @@ -112,6 +109,7 @@ func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []int return statement.ReplaceQuote(condSQL), condArgs, nil } +// ReplaceQuote replace sql key words with quote func (statement *Statement) ReplaceQuote(sql string) string { if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || statement.dialect.URI().DBType == schemas.SQLITE { @@ -591,7 +589,7 @@ func (statement *Statement) Having(conditions string) *Statement { return statement } -// Unscoped always disable struct tag "deleted" +// SetUnscoped always disable struct tag "deleted" func (statement *Statement) SetUnscoped() *Statement { statement.unscoped = true return statement @@ -923,10 +921,7 @@ func (statement *Statement) mergeConds(bean interface{}) error { statement.cond = statement.cond.And(autoCond) } - if err := statement.ProcessIDParam(); err != nil { - return err - } - return nil + return statement.ProcessIDParam() } // GenConds generates conditions diff --git a/internal/statements/statement_args.go b/internal/statements/statement_args.go index dc14467d..64089c1e 100644 --- a/internal/statements/statement_args.go +++ b/internal/statements/statement_args.go @@ -77,6 +77,7 @@ func convertArg(arg interface{}, convertFunc func(string) string) string { const insertSelectPlaceHolder = true +// WriteArg writes an arg func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case *builder.Builder: @@ -116,6 +117,7 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er return nil } +// WriteArgs writes args func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { for i, arg := range args { if err := statement.WriteArg(w, arg); err != nil { diff --git a/names/mapper.go b/names/mapper.go index 79add76e..b0ce8076 100644 --- a/names/mapper.go +++ b/names/mapper.go @@ -16,6 +16,7 @@ type Mapper interface { Table2Obj(string) string } +// CacheMapper represents a cache mapper type CacheMapper struct { oriMapper Mapper obj2tableCache map[string]string @@ -24,12 +25,14 @@ type CacheMapper struct { table2objMutex sync.RWMutex } +// NewCacheMapper creates a cache mapper func NewCacheMapper(mapper Mapper) *CacheMapper { return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string), table2objCache: make(map[string]string), } } +// Obj2Table implements Mapper func (m *CacheMapper) Obj2Table(o string) string { m.obj2tableMutex.RLock() t, ok := m.obj2tableCache[o] @@ -45,6 +48,7 @@ func (m *CacheMapper) Obj2Table(o string) string { return t } +// Table2Obj implements Mapper func (m *CacheMapper) Table2Obj(t string) string { m.table2objMutex.RLock() o, ok := m.table2objCache[t] @@ -60,15 +64,17 @@ func (m *CacheMapper) Table2Obj(t string) string { return o } -// SameMapper implements IMapper and provides same name between struct and +// SameMapper implements Mapper and provides same name between struct and // database table type SameMapper struct { } +// Obj2Table implements Mapper func (m SameMapper) Obj2Table(o string) string { return o } +// Table2Obj implements Mapper func (m SameMapper) Table2Obj(t string) string { return t } @@ -98,6 +104,7 @@ func snakeCasedName(name string) string { return b2s(newstr) } +// Obj2Table implements Mapper func (mapper SnakeMapper) Obj2Table(name string) string { return snakeCasedName(name) } @@ -127,6 +134,7 @@ func titleCasedName(name string) string { return b2s(newstr) } +// Table2Obj implements Mapper func (mapper SnakeMapper) Table2Obj(name string) string { return titleCasedName(name) } @@ -168,10 +176,12 @@ func gonicCasedName(name string) string { return strings.ToLower(string(newstr)) } +// Obj2Table implements Mapper func (mapper GonicMapper) Obj2Table(name string) string { return gonicCasedName(name) } +// Table2Obj implements Mapper func (mapper GonicMapper) Table2Obj(name string) string { newstr := make([]rune, 0) @@ -234,14 +244,17 @@ type PrefixMapper struct { Prefix string } +// Obj2Table implements Mapper func (mapper PrefixMapper) Obj2Table(name string) string { return mapper.Prefix + mapper.Mapper.Obj2Table(name) } +// Table2Obj implements Mapper func (mapper PrefixMapper) Table2Obj(name string) string { return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) } +// NewPrefixMapper creates a prefix mapper func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper { return PrefixMapper{mapper, prefix} } @@ -252,14 +265,17 @@ type SuffixMapper struct { Suffix string } +// Obj2Table implements Mapper func (mapper SuffixMapper) Obj2Table(name string) string { return mapper.Mapper.Obj2Table(name) + mapper.Suffix } +// Table2Obj implements Mapper func (mapper SuffixMapper) Table2Obj(name string) string { return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) } +// NewSuffixMapper creates a suffix mapper func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper { return SuffixMapper{mapper, suffix} } diff --git a/names/table_name.go b/names/table_name.go index 0afb1ae3..cc0e9274 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -19,6 +19,7 @@ var ( tvCache sync.Map ) +// GetTableName returns table name func GetTableName(mapper Mapper, v reflect.Value) string { if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() diff --git a/schemas/column.go b/schemas/column.go index 4f32afab..5808b84d 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -13,6 +13,7 @@ import ( "time" ) +// enumerates all database mapping way const ( TWOSIDES = iota + 1 ONLYTODB diff --git a/schemas/index.go b/schemas/index.go index 9541250f..8f31af52 100644 --- a/schemas/index.go +++ b/schemas/index.go @@ -28,6 +28,7 @@ func NewIndex(name string, indexType int) *Index { return &Index{true, name, indexType, make([]string, 0)} } +// XName returns the special index name for the table func (index *Index) XName(tableName string) string { if !strings.HasPrefix(index.Name, "UQE_") && !strings.HasPrefix(index.Name, "IDX_") { @@ -43,11 +44,10 @@ func (index *Index) XName(tableName string) string { // AddColumn add columns which will be composite index func (index *Index) AddColumn(cols ...string) { - for _, col := range cols { - index.Cols = append(index.Cols, col) - } + index.Cols = append(index.Cols, cols...) } +// Equal return true if the two Index is equal func (index *Index) Equal(dst *Index) bool { if index.Type != dst.Type { return false diff --git a/schemas/pk.go b/schemas/pk.go index 03916b44..da3c7899 100644 --- a/schemas/pk.go +++ b/schemas/pk.go @@ -11,13 +11,16 @@ import ( "xorm.io/xorm/internal/utils" ) +// PK represents primary key values type PK []interface{} +// NewPK creates primay keys func NewPK(pks ...interface{}) *PK { p := PK(pks) return &p } +// IsZero return true if primay keys are zero func (p *PK) IsZero() bool { for _, k := range *p { if utils.IsZero(k) { @@ -27,6 +30,7 @@ func (p *PK) IsZero() bool { return false } +// ToString convert to SQL string func (p *PK) ToString() (string, error) { buf := new(bytes.Buffer) enc := gob.NewEncoder(buf) @@ -34,6 +38,7 @@ func (p *PK) ToString() (string, error) { return buf.String(), err } +// FromString reads content to load primary keys func (p *PK) FromString(content string) error { dec := gob.NewDecoder(bytes.NewBufferString(content)) err := dec.Decode(p) diff --git a/schemas/quote.go b/schemas/quote.go index a0070048..71040ad9 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -16,10 +16,10 @@ type Quoter struct { } var ( - // AlwaysFalseReverse always think it's not a reverse word + // AlwaysNoReserve always think it's not a reverse word AlwaysNoReserve = func(string) bool { return false } - // AlwaysReverse always reverse the word + // AlwaysReserve always reverse the word AlwaysReserve = func(string) bool { return true } // CommanQuoteMark represnets the common quote mark @@ -29,10 +29,12 @@ var ( CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} ) +// IsEmpty return true if no prefix and suffix func (q Quoter) IsEmpty() bool { return q.Prefix == 0 && q.Suffix == 0 } +// Quote quote a string func (q Quoter) Quote(s string) string { var buf strings.Builder q.QuoteTo(&buf, s) @@ -59,12 +61,14 @@ func (q Quoter) Trim(s string) string { return buf.String() } +// Join joins a slice with quoters func (q Quoter) Join(a []string, sep string) string { var b strings.Builder q.JoinWrite(&b, a, sep) return b.String() } +// JoinWrite writes quoted content to a builder func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { if len(a) == 0 { return nil diff --git a/schemas/table.go b/schemas/table.go index 7ca9531f..bfa517aa 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -90,23 +90,28 @@ func (table *Table) PKColumns() []*Column { return columns } +// ColumnType returns a column's type func (table *Table) ColumnType(name string) reflect.Type { t, _ := table.Type.FieldByName(name) return t.Type } +// AutoIncrColumn returns autoincrement column func (table *Table) AutoIncrColumn() *Column { return table.GetColumn(table.AutoIncrement) } +// VersionColumn returns version column's information func (table *Table) VersionColumn() *Column { return table.GetColumn(table.Version) } +// UpdatedColumn returns updated column's information func (table *Table) UpdatedColumn() *Column { return table.GetColumn(table.Updated) } +// DeletedColumn returns deleted column's information func (table *Table) DeletedColumn() *Column { return table.GetColumn(table.Deleted) } diff --git a/schemas/type.go b/schemas/type.go index c6cdfb87..fc02f015 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -222,53 +222,55 @@ var ( // !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison var ( - c_EMPTY_STRING string - c_BOOL_DEFAULT bool - c_BYTE_DEFAULT byte - c_COMPLEX64_DEFAULT complex64 - c_COMPLEX128_DEFAULT complex128 - c_FLOAT32_DEFAULT float32 - c_FLOAT64_DEFAULT float64 - c_INT64_DEFAULT int64 - c_UINT64_DEFAULT uint64 - c_INT32_DEFAULT int32 - c_UINT32_DEFAULT uint32 - c_INT16_DEFAULT int16 - c_UINT16_DEFAULT uint16 - c_INT8_DEFAULT int8 - c_UINT8_DEFAULT uint8 - c_INT_DEFAULT int - c_UINT_DEFAULT uint - c_TIME_DEFAULT time.Time + emptyString string + boolDefault bool + byteDefault byte + complex64Default complex64 + complex128Default complex128 + float32Default float32 + float64Default float64 + int64Default int64 + uint64Default uint64 + int32Default int32 + uint32Default uint32 + int16Default int16 + uint16Default uint16 + int8Default int8 + uint8Default uint8 + intDefault int + uintDefault uint + timeDefault time.Time ) +// enumerates all types var ( - IntType = reflect.TypeOf(c_INT_DEFAULT) - Int8Type = reflect.TypeOf(c_INT8_DEFAULT) - Int16Type = reflect.TypeOf(c_INT16_DEFAULT) - Int32Type = reflect.TypeOf(c_INT32_DEFAULT) - Int64Type = reflect.TypeOf(c_INT64_DEFAULT) + IntType = reflect.TypeOf(intDefault) + Int8Type = reflect.TypeOf(int8Default) + Int16Type = reflect.TypeOf(int16Default) + Int32Type = reflect.TypeOf(int32Default) + Int64Type = reflect.TypeOf(int64Default) - UintType = reflect.TypeOf(c_UINT_DEFAULT) - Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT) - Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT) - Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT) - Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT) + UintType = reflect.TypeOf(uintDefault) + Uint8Type = reflect.TypeOf(uint8Default) + Uint16Type = reflect.TypeOf(uint16Default) + Uint32Type = reflect.TypeOf(uint32Default) + Uint64Type = reflect.TypeOf(uint64Default) - Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT) - Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT) + Float32Type = reflect.TypeOf(float32Default) + Float64Type = reflect.TypeOf(float64Default) - Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT) - Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT) + Complex64Type = reflect.TypeOf(complex64Default) + Complex128Type = reflect.TypeOf(complex128Default) - StringType = reflect.TypeOf(c_EMPTY_STRING) - BoolType = reflect.TypeOf(c_BOOL_DEFAULT) - ByteType = reflect.TypeOf(c_BYTE_DEFAULT) + StringType = reflect.TypeOf(emptyString) + BoolType = reflect.TypeOf(boolDefault) + ByteType = reflect.TypeOf(byteDefault) BytesType = reflect.SliceOf(ByteType) - TimeType = reflect.TypeOf(c_TIME_DEFAULT) + TimeType = reflect.TypeOf(timeDefault) ) +// enumerates all types var ( PtrIntType = reflect.PtrTo(IntType) PtrInt8Type = reflect.PtrTo(Int8Type) @@ -313,7 +315,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.Complex64, reflect.Complex128: st = SQLType{Varchar, 64, 0} case reflect.Array, reflect.Slice, reflect.Map: - if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) { + if t.Elem() == reflect.TypeOf(byteDefault) { st = SQLType{Blob, 0, 0} } else { st = SQLType{Text, 0, 0} @@ -337,7 +339,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { return } -// default sql type change to go types +// SQLType2Type convert default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name) switch name { @@ -356,7 +358,7 @@ func SQLType2Type(st SQLType) reflect.Type { case Bool: return reflect.TypeOf(true) case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: - return reflect.TypeOf(c_TIME_DEFAULT) + return reflect.TypeOf(timeDefault) case Decimal, Numeric, Money, SmallMoney: return reflect.TypeOf("") default: diff --git a/tags/parser.go b/tags/parser.go index 45dd6d9d..5ad67b53 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -21,9 +21,11 @@ import ( ) var ( + // ErrUnsupportedType represents an unsupported type error ErrUnsupportedType = errors.New("Unsupported type") ) +// Parser represents a parser for xorm tag type Parser struct { identifier string dialect dialects.Dialect @@ -34,6 +36,7 @@ type Parser struct { tableCache sync.Map // map[reflect.Type]*schemas.Table } +// NewParser creates a tag parser func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { return &Parser{ identifier: identifier, @@ -45,29 +48,35 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM } } +// GetTableMapper returns table mapper func (parser *Parser) GetTableMapper() names.Mapper { return parser.tableMapper } +// SetTableMapper sets table mapper func (parser *Parser) SetTableMapper(mapper names.Mapper) { parser.ClearCaches() parser.tableMapper = mapper } +// GetColumnMapper returns column mapper func (parser *Parser) GetColumnMapper() names.Mapper { return parser.columnMapper } +// SetColumnMapper sets column mapper func (parser *Parser) SetColumnMapper(mapper names.Mapper) { parser.ClearCaches() parser.columnMapper = mapper } +// SetIdentifier sets tag identifier func (parser *Parser) SetIdentifier(identifier string) { parser.ClearCaches() parser.identifier = identifier } +// ParseWithCache parse a struct with cache func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { t := v.Type() tableI, ok := parser.tableCache.Load(t)