From 5f616ccc5edc784bffbd18246d62c24e11a90605 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 17 Mar 2020 10:22:08 +0800 Subject: [PATCH 1/2] Fix setschema --- dialects/dialect.go | 2 +- engine_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 4b9976f7..4fdf35e9 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -31,7 +31,7 @@ type URI struct { } // SetSchema set schema -func (uri URI) SetSchema(schema string) { +func (uri *URI) SetSchema(schema string) { if uri.DBType == schemas.POSTGRES { uri.Schema = schema } diff --git a/engine_test.go b/engine_test.go index 459d63c4..ab454d0d 100644 --- a/engine_test.go +++ b/engine_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" ) func TestPingContext(t *testing.T) { @@ -97,3 +98,15 @@ func TestDump(t *testing.T) { assert.NoError(t, err) assert.NoError(t, sess.Commit()) } + +func TestSetSchema(t *testing.T) { + assert.NoError(t, prepareEngine()) + + if testEngine.Dialect().URI().DBType == schemas.POSTGRES { + oldSchema := testEngine.Dialect().URI().Schema + testEngine.SetSchema("my_schema") + assert.EqualValues(t, "my_schema", testEngine.Dialect().URI().Schema) + testEngine.SetSchema(oldSchema) + assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema) + } +} -- 2.40.1 From f3dc74c6e07c80b81e9dd538e5e1f575800acd64 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 17 Mar 2020 14:11:27 +0800 Subject: [PATCH 2/2] Fix schema --- dialects/postgres.go | 11 ++--------- engine.go | 41 +++++++++++++++++++++-------------------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index 6fd9e64a..0a851fe2 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1000,7 +1000,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string } func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { - args := []interface{}{tableName} + args := []interface{}{db.uri.Schema, tableName, db.uri.Schema} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -1011,14 +1011,7 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` - - var f string - if len(db.uri.Schema) != 0 { - args = append(args, db.uri.Schema) - f = " AND s.table_schema = $2" - } - s = fmt.Sprintf(s, f) +WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;` rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { diff --git a/engine.go b/engine.go index 99412c4f..52865a3b 100644 --- a/engine.go +++ b/engine.go @@ -350,49 +350,50 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch // dumpTables dump database all table structs and data to w with specify db type func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { - var dialect dialects.Dialect - var distDBName string + var dstDialect dialects.Dialect if len(tp) == 0 { - dialect = engine.dialect - distDBName = string(engine.dialect.URI().DBType) + dstDialect = engine.dialect } else { - dialect = dialects.QueryDialect(tp[0]) - if dialect == nil { + dstDialect = dialects.QueryDialect(tp[0]) + if dstDialect == nil { return errors.New("Unsupported database type") } - var destURI dialects.URI + uri := engine.dialect.URI() - destURI = *uri - dialect.Init(nil, &destURI) - distDBName = string(tp[0]) + destURI := *uri + dstDialect.Init(nil, &destURI) } _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n", - time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName))) + time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType)) if err != nil { return err } for i, table := range tables { + tableName := table.Name + if dstDialect.URI().Schema != "" { + tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name) + } if i > 0 { _, err = io.WriteString(w, "\n") if err != nil { return err } } - sqls, _ := dialect.CreateTableSQL(table, "") + sqls, _ := dstDialect.CreateTableSQL(table, tableName) for _, s := range sqls { _, err = io.WriteString(w, s+";\n") if err != nil { return err } } - if len(table.PKColumns()) > 0 && engine.dialect.URI().DBType == schemas.MSSQL { + if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name) } for _, index := range table.Indexes { - _, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n") + _, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n") if err != nil { return err } @@ -400,9 +401,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch cols := table.ColumnsSeq() colNames := engine.dialect.Quoter().Join(cols, ", ") - destColNames := dialect.Quoter().Join(cols, ", ") + destColNames := dstDialect.Quoter().Join(cols, ", ") - rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(table.Name)) + rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(tableName)) if err != nil { return err } @@ -415,7 +416,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } - _, err = io.WriteString(w, "INSERT INTO "+dialect.Quoter().Quote(table.Name)+" ("+destColNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (") if err != nil { return err } @@ -438,7 +439,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } } else if col.SQLType.IsBlob() { if reflect.TypeOf(d).Kind() == reflect.Slice { - temp += fmt.Sprintf(", %s", dialect.FormatBytes(d.([]byte))) + temp += fmt.Sprintf(", %s", dstDialect.FormatBytes(d.([]byte))) } else if reflect.TypeOf(d).Kind() == reflect.String { temp += fmt.Sprintf(", '%s'", d.(string)) } @@ -485,8 +486,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } // FIXME: Hack for postgres - if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n") + if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { + _, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n") if err != nil { return err } -- 2.40.1