diff --git a/engine.go b/engine.go index f52148a..ee98ed0 100644 --- a/engine.go +++ b/engine.go @@ -168,6 +168,8 @@ func (engine *Engine) SetLogger(logger interface{}) { realLogger = t case log.Logger: realLogger = log.NewLoggerAdapter(t) + default: + panic("logger should implement either log.ContextLogger or log.Logger") } engine.logger = realLogger engine.DB().Logger = realLogger @@ -209,6 +211,11 @@ func (engine *Engine) SetColumnMapper(mapper names.Mapper) { engine.tagParser.SetColumnMapper(mapper) } +// SetTagIdentifier set the tag identifier +func (engine *Engine) SetTagIdentifier(tagIdentifier string) { + engine.tagParser.SetIdentifier(tagIdentifier) +} + // Quote Use QuoteStr quote the string sql func (engine *Engine) Quote(value string) string { value = strings.TrimSpace(value) diff --git a/engine_group.go b/engine_group.go index cdd9dd4..3e91cbd 100644 --- a/engine_group.go +++ b/engine_group.go @@ -167,6 +167,14 @@ func (eg *EngineGroup) SetMapper(mapper names.Mapper) { } } +// SetTagIdentifier set the tag identifier +func (eg *EngineGroup) SetTagIdentifier(tagIdentifier string) { + eg.Engine.SetTagIdentifier(tagIdentifier) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].SetTagIdentifier(tagIdentifier) + } +} + // SetMaxIdleConns set the max idle connections on pool, default is 2 func (eg *EngineGroup) SetMaxIdleConns(conns int) { eg.Engine.DB().SetMaxIdleConns(conns) diff --git a/interface.go b/interface.go index 0fe9cbe..55162c8 100644 --- a/interface.go +++ b/interface.go @@ -101,6 +101,7 @@ type EngineInterface interface { SetCacher(string, caches.Cacher) SetConnMaxLifetime(time.Duration) SetColumnMapper(names.Mapper) + SetTagIdentifier(string) SetDefaultCacher(caches.Cacher) SetLogger(logger interface{}) SetLogLevel(log.LogLevel) diff --git a/session_tx.go b/session_tx.go index 5779170..f50bbf1 100644 --- a/session_tx.go +++ b/session_tx.go @@ -84,3 +84,8 @@ func (session *Session) Commit() error { } return nil } + +// if current session is in a transaction +func (session *Session) IsInTx() bool { + return !session.isAutoCommit +} \ No newline at end of file diff --git a/tags/parser.go b/tags/parser.go index a301d12..45dd6d9 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -63,6 +63,11 @@ func (parser *Parser) SetColumnMapper(mapper names.Mapper) { parser.columnMapper = mapper } +func (parser *Parser) SetIdentifier(identifier string) { + parser.ClearCaches() + parser.identifier = identifier +} + func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { t := v.Type() tableI, ok := parser.tableCache.Load(t) diff --git a/tags/parser_test.go b/tags/parser_test.go index ff304a5..c3bf805 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -80,3 +80,26 @@ func TestUnexportField(t *testing.T) { assert.NotEqual(t, "public", col.Name) } } + +func TestParseWithOtherIdentifier(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.GonicMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithDBTag struct { + FieldFoo string `db:"foo"` + } + parser.SetIdentifier("db") + table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_db_tag", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + + for _, col := range table.Columns() { + assert.EqualValues(t, "foo", col.Name) + } +}