Fix setschema #1606

Merged
lunny merged 2 commits from lunny/fix_set_schema into master 2020-03-17 06:50:10 +00:00
4 changed files with 37 additions and 30 deletions

View File

@ -31,7 +31,7 @@ type URI struct {
}
// SetSchema set schema
func (uri URI) SetSchema(schema string) {
func (uri *URI) SetSchema(schema string) {
if uri.DBType == schemas.POSTGRES {
uri.Schema = schema
}

View File

@ -1000,7 +1000,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string
}
func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -1011,14 +1011,7 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
var f string
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;`
rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {

View File

@ -350,49 +350,50 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
// dumpTables dump database all table structs and data to w with specify db type
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
var dialect dialects.Dialect
var distDBName string
var dstDialect dialects.Dialect
if len(tp) == 0 {
dialect = engine.dialect
distDBName = string(engine.dialect.URI().DBType)
dstDialect = engine.dialect
} else {
dialect = dialects.QueryDialect(tp[0])
if dialect == nil {
dstDialect = dialects.QueryDialect(tp[0])
if dstDialect == nil {
return errors.New("Unsupported database type")
}
var destURI dialects.URI
uri := engine.dialect.URI()
destURI = *uri
dialect.Init(nil, &destURI)
distDBName = string(tp[0])
destURI := *uri
dstDialect.Init(nil, &destURI)
}
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName)))
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType))
if err != nil {
return err
}
for i, table := range tables {
tableName := table.Name
if dstDialect.URI().Schema != "" {
tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name)
}
if i > 0 {
_, err = io.WriteString(w, "\n")
if err != nil {
return err
}
}
sqls, _ := dialect.CreateTableSQL(table, "")
sqls, _ := dstDialect.CreateTableSQL(table, tableName)
for _, s := range sqls {
_, err = io.WriteString(w, s+";\n")
if err != nil {
return err
}
}
if len(table.PKColumns()) > 0 && engine.dialect.URI().DBType == schemas.MSSQL {
if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name)
}
for _, index := range table.Indexes {
_, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n")
_, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n")
if err != nil {
return err
}
@ -400,9 +401,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
cols := table.ColumnsSeq()
colNames := engine.dialect.Quoter().Join(cols, ", ")
destColNames := dialect.Quoter().Join(cols, ", ")
destColNames := dstDialect.Quoter().Join(cols, ", ")
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(table.Name))
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(tableName))
if err != nil {
return err
}
@ -415,7 +416,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dialect.Quoter().Quote(table.Name)+" ("+destColNames+") VALUES (")
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
@ -438,7 +439,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
} else if col.SQLType.IsBlob() {
if reflect.TypeOf(d).Kind() == reflect.Slice {
temp += fmt.Sprintf(", %s", dialect.FormatBytes(d.([]byte)))
temp += fmt.Sprintf(", %s", dstDialect.FormatBytes(d.([]byte)))
} else if reflect.TypeOf(d).Kind() == reflect.String {
temp += fmt.Sprintf(", '%s'", d.(string))
}
@ -485,8 +486,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
// FIXME: Hack for postgres
if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n")
if err != nil {
return err
}

View File

@ -12,6 +12,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/schemas"
)
func TestPingContext(t *testing.T) {
@ -97,3 +98,15 @@ func TestDump(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, sess.Commit())
}
func TestSetSchema(t *testing.T) {
assert.NoError(t, prepareEngine())
if testEngine.Dialect().URI().DBType == schemas.POSTGRES {
oldSchema := testEngine.Dialect().URI().Schema
testEngine.SetSchema("my_schema")
assert.EqualValues(t, "my_schema", testEngine.Dialect().URI().Schema)
testEngine.SetSchema(oldSchema)
assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema)
}
}