Fix setschema #1606
|
@ -31,7 +31,7 @@ type URI struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSchema set schema
|
// SetSchema set schema
|
||||||
func (uri URI) SetSchema(schema string) {
|
func (uri *URI) SetSchema(schema string) {
|
||||||
if uri.DBType == schemas.POSTGRES {
|
if uri.DBType == schemas.POSTGRES {
|
||||||
uri.Schema = schema
|
uri.Schema = schema
|
||||||
}
|
}
|
||||||
|
|
|
@ -1000,7 +1000,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema}
|
||||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
||||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
||||||
|
@ -1011,14 +1011,7 @@ FROM pg_attribute f
|
||||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||||
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
||||||
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
|
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;`
|
||||||
|
|
||||||
var f string
|
|
||||||
if len(db.uri.Schema) != 0 {
|
|
||||||
args = append(args, db.uri.Schema)
|
|
||||||
f = " AND s.table_schema = $2"
|
|
||||||
}
|
|
||||||
s = fmt.Sprintf(s, f)
|
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := db.DB().QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
41
engine.go
41
engine.go
|
@ -350,49 +350,50 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
|
|
||||||
// dumpTables dump database all table structs and data to w with specify db type
|
// dumpTables dump database all table structs and data to w with specify db type
|
||||||
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
|
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
|
||||||
var dialect dialects.Dialect
|
var dstDialect dialects.Dialect
|
||||||
var distDBName string
|
|
||||||
if len(tp) == 0 {
|
if len(tp) == 0 {
|
||||||
dialect = engine.dialect
|
dstDialect = engine.dialect
|
||||||
distDBName = string(engine.dialect.URI().DBType)
|
|
||||||
} else {
|
} else {
|
||||||
dialect = dialects.QueryDialect(tp[0])
|
dstDialect = dialects.QueryDialect(tp[0])
|
||||||
if dialect == nil {
|
if dstDialect == nil {
|
||||||
return errors.New("Unsupported database type")
|
return errors.New("Unsupported database type")
|
||||||
}
|
}
|
||||||
var destURI dialects.URI
|
|
||||||
uri := engine.dialect.URI()
|
uri := engine.dialect.URI()
|
||||||
destURI = *uri
|
destURI := *uri
|
||||||
dialect.Init(nil, &destURI)
|
dstDialect.Init(nil, &destURI)
|
||||||
distDBName = string(tp[0])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
|
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
|
||||||
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName)))
|
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, table := range tables {
|
for i, table := range tables {
|
||||||
|
tableName := table.Name
|
||||||
|
if dstDialect.URI().Schema != "" {
|
||||||
|
tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name)
|
||||||
|
}
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
_, err = io.WriteString(w, "\n")
|
_, err = io.WriteString(w, "\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sqls, _ := dialect.CreateTableSQL(table, "")
|
sqls, _ := dstDialect.CreateTableSQL(table, tableName)
|
||||||
for _, s := range sqls {
|
for _, s := range sqls {
|
||||||
_, err = io.WriteString(w, s+";\n")
|
_, err = io.WriteString(w, s+";\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(table.PKColumns()) > 0 && engine.dialect.URI().DBType == schemas.MSSQL {
|
if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
|
||||||
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name)
|
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, index := range table.Indexes {
|
for _, index := range table.Indexes {
|
||||||
_, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n")
|
_, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -400,9 +401,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
|
|
||||||
cols := table.ColumnsSeq()
|
cols := table.ColumnsSeq()
|
||||||
colNames := engine.dialect.Quoter().Join(cols, ", ")
|
colNames := engine.dialect.Quoter().Join(cols, ", ")
|
||||||
destColNames := dialect.Quoter().Join(cols, ", ")
|
destColNames := dstDialect.Quoter().Join(cols, ", ")
|
||||||
|
|
||||||
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(table.Name))
|
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(tableName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -415,7 +416,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = io.WriteString(w, "INSERT INTO "+dialect.Quoter().Quote(table.Name)+" ("+destColNames+") VALUES (")
|
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -438,7 +439,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
}
|
}
|
||||||
} else if col.SQLType.IsBlob() {
|
} else if col.SQLType.IsBlob() {
|
||||||
if reflect.TypeOf(d).Kind() == reflect.Slice {
|
if reflect.TypeOf(d).Kind() == reflect.Slice {
|
||||||
temp += fmt.Sprintf(", %s", dialect.FormatBytes(d.([]byte)))
|
temp += fmt.Sprintf(", %s", dstDialect.FormatBytes(d.([]byte)))
|
||||||
} else if reflect.TypeOf(d).Kind() == reflect.String {
|
} else if reflect.TypeOf(d).Kind() == reflect.String {
|
||||||
temp += fmt.Sprintf(", '%s'", d.(string))
|
temp += fmt.Sprintf(", '%s'", d.(string))
|
||||||
}
|
}
|
||||||
|
@ -485,8 +486,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: Hack for postgres
|
// FIXME: Hack for postgres
|
||||||
if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
||||||
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
|
_, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPingContext(t *testing.T) {
|
func TestPingContext(t *testing.T) {
|
||||||
|
@ -97,3 +98,15 @@ func TestDump(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NoError(t, sess.Commit())
|
assert.NoError(t, sess.Commit())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetSchema(t *testing.T) {
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
|
||||||
|
if testEngine.Dialect().URI().DBType == schemas.POSTGRES {
|
||||||
|
oldSchema := testEngine.Dialect().URI().Schema
|
||||||
|
testEngine.SetSchema("my_schema")
|
||||||
|
assert.EqualValues(t, "my_schema", testEngine.Dialect().URI().Schema)
|
||||||
|
testEngine.SetSchema(oldSchema)
|
||||||
|
assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user