Fix setschema #1606
|
@ -31,7 +31,7 @@ type URI struct {
|
|||
}
|
||||
|
||||
// SetSchema set schema
|
||||
func (uri URI) SetSchema(schema string) {
|
||||
func (uri *URI) SetSchema(schema string) {
|
||||
if uri.DBType == schemas.POSTGRES {
|
||||
uri.Schema = schema
|
||||
}
|
||||
|
|
|
@ -1000,7 +1000,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string
|
|||
}
|
||||
|
||||
func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []interface{}{tableName}
|
||||
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema}
|
||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
||||
|
@ -1011,14 +1011,7 @@ FROM pg_attribute f
|
|||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
||||
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
|
||||
|
||||
var f string
|
||||
if len(db.uri.Schema) != 0 {
|
||||
args = append(args, db.uri.Schema)
|
||||
f = " AND s.table_schema = $2"
|
||||
}
|
||||
s = fmt.Sprintf(s, f)
|
||||
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;`
|
||||
|
||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
||||
if err != nil {
|
||||
|
|
41
engine.go
41
engine.go
|
@ -350,49 +350,50 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
|||
|
||||
// dumpTables dump database all table structs and data to w with specify db type
|
||||
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
|
||||
var dialect dialects.Dialect
|
||||
var distDBName string
|
||||
var dstDialect dialects.Dialect
|
||||
if len(tp) == 0 {
|
||||
dialect = engine.dialect
|
||||
distDBName = string(engine.dialect.URI().DBType)
|
||||
dstDialect = engine.dialect
|
||||
} else {
|
||||
dialect = dialects.QueryDialect(tp[0])
|
||||
if dialect == nil {
|
||||
dstDialect = dialects.QueryDialect(tp[0])
|
||||
if dstDialect == nil {
|
||||
return errors.New("Unsupported database type")
|
||||
}
|
||||
var destURI dialects.URI
|
||||
|
||||
uri := engine.dialect.URI()
|
||||
destURI = *uri
|
||||
dialect.Init(nil, &destURI)
|
||||
distDBName = string(tp[0])
|
||||
destURI := *uri
|
||||
dstDialect.Init(nil, &destURI)
|
||||
}
|
||||
|
||||
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
|
||||
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName)))
|
||||
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, table := range tables {
|
||||
tableName := table.Name
|
||||
if dstDialect.URI().Schema != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name)
|
||||
}
|
||||
if i > 0 {
|
||||
_, err = io.WriteString(w, "\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
sqls, _ := dialect.CreateTableSQL(table, "")
|
||||
sqls, _ := dstDialect.CreateTableSQL(table, tableName)
|
||||
for _, s := range sqls {
|
||||
_, err = io.WriteString(w, s+";\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(table.PKColumns()) > 0 && engine.dialect.URI().DBType == schemas.MSSQL {
|
||||
if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
|
||||
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name)
|
||||
}
|
||||
|
||||
for _, index := range table.Indexes {
|
||||
_, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n")
|
||||
_, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -400,9 +401,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
|||
|
||||
cols := table.ColumnsSeq()
|
||||
colNames := engine.dialect.Quoter().Join(cols, ", ")
|
||||
destColNames := dialect.Quoter().Join(cols, ", ")
|
||||
destColNames := dstDialect.Quoter().Join(cols, ", ")
|
||||
|
||||
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(table.Name))
|
||||
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(tableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -415,7 +416,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
|||
return err
|
||||
}
|
||||
|
||||
_, err = io.WriteString(w, "INSERT INTO "+dialect.Quoter().Quote(table.Name)+" ("+destColNames+") VALUES (")
|
||||
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -438,7 +439,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
|||
}
|
||||
} else if col.SQLType.IsBlob() {
|
||||
if reflect.TypeOf(d).Kind() == reflect.Slice {
|
||||
temp += fmt.Sprintf(", %s", dialect.FormatBytes(d.([]byte)))
|
||||
temp += fmt.Sprintf(", %s", dstDialect.FormatBytes(d.([]byte)))
|
||||
} else if reflect.TypeOf(d).Kind() == reflect.String {
|
||||
temp += fmt.Sprintf(", '%s'", d.(string))
|
||||
}
|
||||
|
@ -485,8 +486,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
|||
}
|
||||
|
||||
// FIXME: Hack for postgres
|
||||
if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
||||
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
|
||||
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
||||
_, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
func TestPingContext(t *testing.T) {
|
||||
|
@ -97,3 +98,15 @@ func TestDump(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NoError(t, sess.Commit())
|
||||
}
|
||||
|
||||
func TestSetSchema(t *testing.T) {
|
||||
assert.NoError(t, prepareEngine())
|
||||
|
||||
if testEngine.Dialect().URI().DBType == schemas.POSTGRES {
|
||||
oldSchema := testEngine.Dialect().URI().Schema
|
||||
testEngine.SetSchema("my_schema")
|
||||
assert.EqualValues(t, "my_schema", testEngine.Dialect().URI().Schema)
|
||||
testEngine.SetSchema(oldSchema)
|
||||
assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user