reverse/cmd/reverse.go
2022-06-03 22:35:24 +08:00

268 lines
5.4 KiB
Go

// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmd
import (
"bytes"
"errors"
"html/template"
"io/ioutil"
"os"
"path/filepath"
"strings"
"gitea.com/polaris/reverse/pkg/conf"
"gitea.com/polaris/reverse/pkg/language"
"gitea.com/polaris/reverse/pkg/utils"
"gitea.com/lunny/log"
underscore "github.com/ahl5esoft/golang-underscore"
"github.com/gobwas/glob"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
// defaultFuncs holds the template helpers registered for every template,
// whatever the target language: "UnTitle" is utils.UnTitle and "Upper" is
// utils.UpTitle. Language-specific helpers are merged on top in runReverse.
var defaultFuncs = template.FuncMap{
	"UnTitle": utils.UnTitle,
	"Upper":   utils.UpTitle,
}
// reverseFromConfig loads the reverse configurations from the YAML file rFile
// and calls runReverse once for every target of every configuration, in
// order. The first error, from loading the file or from any run, is returned
// immediately and the remaining targets are skipped.
func reverseFromConfig(rFile string) error {
	configs, loadErr := conf.NewReverseConfigFromYAML(rFile)
	if loadErr != nil {
		return loadErr
	}

	for _, c := range configs {
		src := &c.Source
		for _, t := range c.Targets {
			// t is a copy of the configured target; runReverse mutates it
			// (for example its ExtName), so pass the copy's address.
			if runErr := runReverse(src, &t); runErr != nil {
				return runErr
			}
		}
	}
	return nil
}
// filterTables selects the tables a target should generate code for.
//
// When the package-level tableName is non-empty (it is declared outside this
// function; presumably set from a command-line flag — confirm), only the table
// whose Name equals it is returned, and the function panics if there is none.
//
// Otherwise a table is kept when its name matches none of
// target.ExcludeTables and, if target.IncludeTables is non-empty, matches at
// least one of target.IncludeTables (an empty include list keeps every table
// that was not excluded). Patterns use github.com/gobwas/glob syntax and are
// compiled once up front; an invalid pattern panics with a message naming it
// instead of dereferencing a nil glob.
//
// Note: the returned pointers refer to copies of the input tables made by the
// value-typed underscore callbacks, not to the original *schemas.Table values.
func filterTables(tables []*schemas.Table, target *conf.ReverseTarget) []*schemas.Table {
	// support one table from flag
	if tableName != "" {
		for _, table := range tables {
			if table.Name == tableName {
				return []*schemas.Table{table}
			}
		}
		panic(tableName + " is not exists!")
	}

	// compileAll turns patterns into matchers, failing loudly on a bad pattern
	// (glob.Compile returns a nil Glob on error, which would otherwise panic
	// later with an unhelpful nil-pointer dereference).
	compileAll := func(patterns []string) []glob.Glob {
		globs := make([]glob.Glob, 0, len(patterns))
		for _, p := range patterns {
			g, err := glob.Compile(p)
			if err != nil {
				panic("invalid table pattern \"" + p + "\": " + err.Error())
			}
			globs = append(globs, g)
		}
		return globs
	}
	excludes := compileAll(target.ExcludeTables)
	includes := compileAll(target.IncludeTables)

	// matchAny reports whether name matches at least one of globs.
	matchAny := func(globs []glob.Glob, name string) bool {
		for _, g := range globs {
			if g.Match(name) {
				return true
			}
		}
		return false
	}

	var res = make([]*schemas.Table, 0, len(tables))
	underscore.Chain(tables).
		Filter(func(tbl schemas.Table, _ int) bool {
			return !matchAny(excludes, tbl.Name)
		}).
		Filter(func(tbl schemas.Table, _ int) bool {
			// if not set, all tables by default
			return len(includes) == 0 || matchAny(includes, tbl.Name)
		}).
		Each(func(tbl schemas.Table, _ int) {
			res = append(res, &tbl)
		})
	return res
}
// runReverse generates source files for one target from the schema of the
// database described by source.
//
// It:
//   - opens an xorm engine for source.Database / source.ConnStr, reads the
//     table metadata with DBMetas and keeps the tables filterTables accepts;
//   - picks the template text: target.Template, else the file at
//     target.TemplatePath, else the template of target.Language;
//   - builds the template funcs from defaultFuncs, "TableMapper" /
//     "ColumnMapper" (chosen by target.TableMapper / target.ColumnMapper) and
//     the language's funcs;
//   - strips target.TablePrefix from table names and sets every column's
//     FieldName through the column mapper;
//   - renders into target.OutputDir: a single "models<ext>" file, or one
//     "<table name><ext>" file per table when target.MultipleFiles is set.
//
// target is modified in place: a found language is bound to it and sets
// ExtName, and ExtName is always given a leading ".".
//
// It returns an error when no template source is available, and on any
// database, template-parsing/execution or file-system error. Formatter
// errors are only logged; the unformatted output is written instead.
func runReverse(source *conf.ReverseSource, target *conf.ReverseTarget) error {
	var (
		formatter func(string) (string, error)
		importter func([]*schemas.Table) []string
	)

	orm, err := xorm.NewEngine(source.Database, source.ConnStr)
	if err != nil {
		return err
	}
	// The engine is only used to read metadata; release it on return. The
	// Close error is ignored because the metadata has already been read.
	defer orm.Close()

	tables, err := orm.DBMetas()
	if err != nil {
		return err
	}

	// filter tables according to the include and exclude patterns
	tables = filterTables(tables, target)

	// load configuration from language
	lang := language.GetLanguage(target.Language, target.TableName)

	// load template; an explicit template wins over a template path
	var bs []byte
	if target.Template != "" {
		bs = []byte(target.Template)
	} else if target.TemplatePath != "" {
		bs, err = ioutil.ReadFile(target.TemplatePath)
		if err != nil {
			return err
		}
	}

	tableMapper := utils.GetMapperByName(target.TableMapper)
	colMapper := utils.GetMapperByName(target.ColumnMapper)
	funcs := utils.MergeFuncMap(
		template.FuncMap(defaultFuncs),
		template.FuncMap{
			"TableMapper":  tableMapper.Table2Obj,
			"ColumnMapper": colMapper.Table2Obj,
		})

	if lang != nil {
		lang.BindTarget(target)
		// the language template is only a fallback
		if bs == nil {
			bs = []byte(lang.GetTemplate())
		}
		funcs = utils.MergeFuncMap(funcs, lang.GetFuncs())
		formatter = lang.GetFormatter()
		importter = lang.GetImportter()
		target.ExtName = lang.GetExtName()
	}

	if !strings.HasPrefix(target.ExtName, ".") {
		target.ExtName = "." + target.ExtName
	}

	if bs == nil {
		return errors.New("you have to indicate template / template path or a language")
	}

	// NOTE(review): this is html/template (see the file's imports), which
	// HTML-escapes interpolated values; confirm that is intended for
	// generated source code.
	tmpl, err := template.New("reverse").Funcs(funcs).Parse(string(bs))
	if err != nil {
		return err
	}

	for _, table := range tables {
		if target.TablePrefix != "" {
			table.Name = strings.TrimPrefix(table.Name, target.TablePrefix)
		}
		for _, col := range table.Columns() {
			col.FieldName = colMapper.Table2Obj(col.Name)
		}
	}

	if err := os.MkdirAll(target.OutputDir, os.ModePerm); err != nil {
		return err
	}

	if !target.MultipleFiles {
		path := filepath.Join(target.OutputDir, "models"+target.ExtName)
		return renderReverseFile(tmpl, path, tables, importter, formatter)
	}
	for _, table := range tables {
		path := filepath.Join(target.OutputDir, table.Name+target.ExtName)
		if err := renderReverseFile(tmpl, path, []*schemas.Table{table}, importter, formatter); err != nil {
			return err
		}
	}
	return nil
}

// renderReverseFile executes tmpl with {"Tables": tables, "Imports": imports}
// and writes the result to path, creating or truncating the file.
// imports comes from importter when it is non-nil and is nil otherwise (no
// language configured). When formatter is non-nil its output is written; if it
// fails, the error is logged and the unformatted text is written instead.
// The template is rendered before the file is opened, so a template error
// leaves no partially written file behind.
func renderReverseFile(
	tmpl *template.Template,
	path string,
	tables []*schemas.Table,
	importter func([]*schemas.Table) []string,
	formatter func(string) (string, error),
) error {
	var imports []string
	if importter != nil {
		imports = importter(tables)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, map[string]interface{}{
		"Tables":  tables,
		"Imports": imports,
	}); err != nil {
		return err
	}

	content := buf.String()
	if formatter != nil {
		formatted, err := formatter(content)
		if err != nil {
			log.Warnf("%v", err)
		} else {
			content = formatted
		}
	}

	// 0666 matches the permissions os.Create uses (before umask).
	return ioutil.WriteFile(path, []byte(content), 0666)
}