This repository has been archived on 2020-04-27. You can view files and clone it, but cannot push or open issues or pull requests.
cmd/xorm/go.go
Lunny Xiao 9ccdb0ebed
All checks were successful
continuous-integration/drone/push Build is passing
upgrade to xorm v1.0.1
2020-04-27 09:43:43 +08:00

326 lines
6.8 KiB
Go

// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"errors"
"fmt"
"go/format"
"reflect"
"sort"
"strings"
"text/template"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
var (
supportComment bool
GoLangTmpl LangTmpl = LangTmpl{
template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": typestring,
"Tag": tag,
"UnTitle": unTitle,
"gt": gt,
"getCol": getCol,
"UpperTitle": upTitle,
},
formatGo,
genGoImports,
}
)
var (
errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison")
)
type kind int
const (
invalidKind kind = iota
boolKind
complexKind
intKind
floatKind
integerKind
stringKind
uintKind
)
func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() {
case reflect.Bool:
return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil
case reflect.Float32, reflect.Float64:
return floatKind, nil
case reflect.Complex64, reflect.Complex128:
return complexKind, nil
case reflect.String:
return stringKind, nil
}
return invalidKind, errBadComparisonType
}
// eq evaluates the comparison a == b || a == c || ...
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
if len(arg2) == 0 {
return false, errNoComparison
}
for _, arg := range arg2 {
v2 := reflect.ValueOf(arg)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
if k1 != k2 {
return false, errBadComparison
}
truth := false
switch k1 {
case boolKind:
truth = v1.Bool() == v2.Bool()
case complexKind:
truth = v1.Complex() == v2.Complex()
case floatKind:
truth = v1.Float() == v2.Float()
case intKind:
truth = v1.Int() == v2.Int()
case stringKind:
truth = v1.String() == v2.String()
case uintKind:
truth = v1.Uint() == v2.Uint()
default:
panic("invalid kind")
}
if truth {
return true, nil
}
}
return false, nil
}
// lt evaluates the comparison a < b.
func lt(arg1, arg2 interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
v2 := reflect.ValueOf(arg2)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
if k1 != k2 {
return false, errBadComparison
}
truth := false
switch k1 {
case boolKind, complexKind:
return false, errBadComparisonType
case floatKind:
truth = v1.Float() < v2.Float()
case intKind:
truth = v1.Int() < v2.Int()
case stringKind:
truth = v1.String() < v2.String()
case uintKind:
truth = v1.Uint() < v2.Uint()
default:
panic("invalid kind")
}
return truth, nil
}
// le evaluates the comparison <= b.
func le(arg1, arg2 interface{}) (bool, error) {
// <= is < or ==.
lessThan, err := lt(arg1, arg2)
if lessThan || err != nil {
return lessThan, err
}
return eq(arg1, arg2)
}
// gt evaluates the comparison a > b.
func gt(arg1, arg2 interface{}) (bool, error) {
// > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2)
if err != nil {
return false, err
}
return !lessOrEqual, nil
}
func getCol(cols map[string]*core.Column, name string) *core.Column {
return cols[strings.ToLower(name)]
}
func formatGo(src string) (string, error) {
source, err := format.Source([]byte(src))
if err != nil {
return "", err
}
return string(source), nil
}
func genGoImports(tables []*schemas.Table) map[string]string {
imports := make(map[string]string)
for _, table := range tables {
for _, col := range table.Columns() {
if typestring(col) == "time.Time" {
imports["time"] = "time"
}
}
}
return imports
}
func typestring(col *schemas.Column) string {
st := col.SQLType
t := schemas.SQLType2Type(st)
s := t.String()
if s == "[]uint8" {
return "[]byte"
}
return s
}
func tag(table *schemas.Table, col *schemas.Column) string {
isNameId := (mapper.Table2Obj(col.Name) == "Id")
isIdPk := isNameId && typestring(col) == "int64"
var res []string
if !col.Nullable {
if !isIdPk {
res = append(res, "not null")
}
}
if col.IsPrimaryKey {
res = append(res, "pk")
}
if col.Default != "" {
res = append(res, "default "+col.Default)
}
if col.IsAutoIncrement {
res = append(res, "autoincr")
}
if col.SQLType.IsTime() && include(created, col.Name) {
res = append(res, "created")
}
if col.SQLType.IsTime() && include(updated, col.Name) {
res = append(res, "updated")
}
if col.SQLType.IsTime() && include(deleted, col.Name) {
res = append(res, "deleted")
}
if supportComment && col.Comment != "" {
res = append(res, fmt.Sprintf("comment('%s')", col.Comment))
}
names := make([]string, 0, len(col.Indexes))
for name := range col.Indexes {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
index := table.Indexes[name]
var uistr string
if index.Type == core.UniqueType {
uistr = "unique"
} else if index.Type == core.IndexType {
uistr = "index"
}
if len(index.Cols) > 1 {
uistr += "(" + index.Name + ")"
}
res = append(res, uistr)
}
nstr := col.SQLType.Name
if col.Length != 0 {
if col.Length2 != 0 {
nstr += fmt.Sprintf("(%v,%v)", col.Length, col.Length2)
} else {
nstr += fmt.Sprintf("(%v)", col.Length)
}
} else if len(col.EnumOptions) > 0 { //enum
nstr += "("
opts := ""
enumOptions := make([]string, 0, len(col.EnumOptions))
for enumOption := range col.EnumOptions {
enumOptions = append(enumOptions, enumOption)
}
sort.Strings(enumOptions)
for _, v := range enumOptions {
opts += fmt.Sprintf(",'%v'", v)
}
nstr += strings.TrimLeft(opts, ",")
nstr += ")"
} else if len(col.SetOptions) > 0 { //enum
nstr += "("
opts := ""
setOptions := make([]string, 0, len(col.SetOptions))
for setOption := range col.SetOptions {
setOptions = append(setOptions, setOption)
}
sort.Strings(setOptions)
for _, v := range setOptions {
opts += fmt.Sprintf(",'%v'", v)
}
nstr += strings.TrimLeft(opts, ",")
nstr += ")"
}
res = append(res, nstr)
var tags []string
if genJson {
if include(ignoreColumnsJSON, col.Name) {
tags = append(tags, "json:\"-\"")
} else {
tags = append(tags, "json:\""+col.Name+"\"")
}
}
if len(res) > 0 {
tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"")
}
if len(tags) > 0 {
return "`" + strings.Join(tags, " ") + "`"
} else {
return ""
}
}
func include(source []string, target string) bool {
for _, s := range source {
if s == target {
return true
}
}
return false
}