Simple and powerful ORM for Go; supports mysql, postgres, tidb, sqlite3, mssql, oracle. https://xorm.io
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

491 lines
13 KiB

  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "bufio"
  7. "database/sql"
  8. "fmt"
  9. "io"
  10. "os"
  11. "strings"
  12. "xorm.io/xorm/internal/utils"
  13. "xorm.io/xorm/schemas"
  14. )
  15. // Ping test if database is ok
  16. func (session *Session) Ping() error {
  17. if session.isAutoClose {
  18. defer session.Close()
  19. }
  20. session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
  21. return session.DB().PingContext(session.ctx)
  22. }
  23. // CreateTable create a table according a bean
  24. func (session *Session) CreateTable(bean interface{}) error {
  25. if session.isAutoClose {
  26. defer session.Close()
  27. }
  28. return session.createTable(bean)
  29. }
  30. func (session *Session) createTable(bean interface{}) error {
  31. if err := session.statement.SetRefBean(bean); err != nil {
  32. return err
  33. }
  34. sqlStrs := session.statement.GenCreateTableSQL()
  35. for _, s := range sqlStrs {
  36. _, err := session.exec(s)
  37. if err != nil {
  38. return err
  39. }
  40. }
  41. return nil
  42. }
  43. // CreateIndexes create indexes
  44. func (session *Session) CreateIndexes(bean interface{}) error {
  45. if session.isAutoClose {
  46. defer session.Close()
  47. }
  48. return session.createIndexes(bean)
  49. }
  50. func (session *Session) createIndexes(bean interface{}) error {
  51. if err := session.statement.SetRefBean(bean); err != nil {
  52. return err
  53. }
  54. sqls := session.statement.GenIndexSQL()
  55. for _, sqlStr := range sqls {
  56. _, err := session.exec(sqlStr)
  57. if err != nil {
  58. return err
  59. }
  60. }
  61. return nil
  62. }
  63. // CreateUniques create uniques
  64. func (session *Session) CreateUniques(bean interface{}) error {
  65. if session.isAutoClose {
  66. defer session.Close()
  67. }
  68. return session.createUniques(bean)
  69. }
  70. func (session *Session) createUniques(bean interface{}) error {
  71. if err := session.statement.SetRefBean(bean); err != nil {
  72. return err
  73. }
  74. sqls := session.statement.GenUniqueSQL()
  75. for _, sqlStr := range sqls {
  76. _, err := session.exec(sqlStr)
  77. if err != nil {
  78. return err
  79. }
  80. }
  81. return nil
  82. }
  83. // DropIndexes drop indexes
  84. func (session *Session) DropIndexes(bean interface{}) error {
  85. if session.isAutoClose {
  86. defer session.Close()
  87. }
  88. return session.dropIndexes(bean)
  89. }
  90. func (session *Session) dropIndexes(bean interface{}) error {
  91. if err := session.statement.SetRefBean(bean); err != nil {
  92. return err
  93. }
  94. sqls := session.statement.GenDelIndexSQL()
  95. for _, sqlStr := range sqls {
  96. _, err := session.exec(sqlStr)
  97. if err != nil {
  98. return err
  99. }
  100. }
  101. return nil
  102. }
  103. // DropTable drop table will drop table if exist, if drop failed, it will return error
  104. func (session *Session) DropTable(beanOrTableName interface{}) error {
  105. if session.isAutoClose {
  106. defer session.Close()
  107. }
  108. return session.dropTable(beanOrTableName)
  109. }
  110. func (session *Session) dropTable(beanOrTableName interface{}) error {
  111. tableName := session.engine.TableName(beanOrTableName)
  112. sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
  113. if !checkIfExist {
  114. exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
  115. if err != nil {
  116. return err
  117. }
  118. checkIfExist = exist
  119. }
  120. if checkIfExist {
  121. _, err := session.exec(sqlStr)
  122. return err
  123. }
  124. return nil
  125. }
  126. // IsTableExist if a table is exist
  127. func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
  128. if session.isAutoClose {
  129. defer session.Close()
  130. }
  131. tableName := session.engine.TableName(beanOrTableName)
  132. return session.isTableExist(tableName)
  133. }
  134. func (session *Session) isTableExist(tableName string) (bool, error) {
  135. return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
  136. }
  137. // IsTableEmpty if table have any records
  138. func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
  139. if session.isAutoClose {
  140. defer session.Close()
  141. }
  142. return session.isTableEmpty(session.engine.TableName(bean))
  143. }
  144. func (session *Session) isTableEmpty(tableName string) (bool, error) {
  145. var total int64
  146. sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
  147. err := session.queryRow(sqlStr).Scan(&total)
  148. if err != nil {
  149. if err == sql.ErrNoRows {
  150. err = nil
  151. }
  152. return true, err
  153. }
  154. return total == 0, nil
  155. }
  156. // find if index is exist according cols
  157. func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
  158. indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName)
  159. if err != nil {
  160. return false, err
  161. }
  162. for _, index := range indexes {
  163. if utils.SliceEq(index.Cols, cols) {
  164. if unique {
  165. return index.Type == schemas.UniqueType, nil
  166. }
  167. return index.Type == schemas.IndexType, nil
  168. }
  169. }
  170. return false, nil
  171. }
  172. func (session *Session) addColumn(colName string) error {
  173. col := session.statement.RefTable.GetColumn(colName)
  174. sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col)
  175. _, err := session.exec(sql)
  176. return err
  177. }
  178. func (session *Session) addIndex(tableName, idxName string) error {
  179. index := session.statement.RefTable.Indexes[idxName]
  180. sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
  181. _, err := session.exec(sqlStr)
  182. return err
  183. }
  184. func (session *Session) addUnique(tableName, uqeName string) error {
  185. index := session.statement.RefTable.Indexes[uqeName]
  186. sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
  187. _, err := session.exec(sqlStr)
  188. return err
  189. }
  190. // Sync2 synchronize structs to database tables
  191. func (session *Session) Sync2(beans ...interface{}) error {
  192. engine := session.engine
  193. if session.isAutoClose {
  194. session.isAutoClose = false
  195. defer session.Close()
  196. }
  197. tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
  198. if err != nil {
  199. return err
  200. }
  201. session.autoResetStatement = false
  202. defer func() {
  203. session.autoResetStatement = true
  204. session.resetStatement()
  205. }()
  206. for _, bean := range beans {
  207. v := utils.ReflectValue(bean)
  208. table, err := engine.tagParser.ParseWithCache(v)
  209. if err != nil {
  210. return err
  211. }
  212. var tbName string
  213. if len(session.statement.AltTableName) > 0 {
  214. tbName = session.statement.AltTableName
  215. } else {
  216. tbName = engine.TableName(bean)
  217. }
  218. tbNameWithSchema := engine.tbNameWithSchema(tbName)
  219. var oriTable *schemas.Table
  220. for _, tb := range tables {
  221. if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
  222. oriTable = tb
  223. break
  224. }
  225. }
  226. // this is a new table
  227. if oriTable == nil {
  228. err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
  229. if err != nil {
  230. return err
  231. }
  232. err = session.createUniques(bean)
  233. if err != nil {
  234. return err
  235. }
  236. err = session.createIndexes(bean)
  237. if err != nil {
  238. return err
  239. }
  240. continue
  241. }
  242. // this will modify an old table
  243. if err = engine.loadTableInfo(oriTable); err != nil {
  244. return err
  245. }
  246. // check columns
  247. for _, col := range table.Columns() {
  248. var oriCol *schemas.Column
  249. for _, col2 := range oriTable.Columns() {
  250. if strings.EqualFold(col.Name, col2.Name) {
  251. oriCol = col2
  252. break
  253. }
  254. }
  255. // column is not exist on table
  256. if oriCol == nil {
  257. session.statement.RefTable = table
  258. session.statement.SetTableName(tbNameWithSchema)
  259. if err = session.addColumn(col.Name); err != nil {
  260. return err
  261. }
  262. continue
  263. }
  264. err = nil
  265. expectedType := engine.dialect.SQLType(col)
  266. curType := engine.dialect.SQLType(oriCol)
  267. if expectedType != curType {
  268. if expectedType == schemas.Text &&
  269. strings.HasPrefix(curType, schemas.Varchar) {
  270. // currently only support mysql & postgres
  271. if engine.dialect.URI().DBType == schemas.MYSQL ||
  272. engine.dialect.URI().DBType == schemas.POSTGRES {
  273. engine.logger.Infof("Table %s column %s change type from %s to %s\n",
  274. tbNameWithSchema, col.Name, curType, expectedType)
  275. _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
  276. } else {
  277. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
  278. tbNameWithSchema, col.Name, curType, expectedType)
  279. }
  280. } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
  281. if engine.dialect.URI().DBType == schemas.MYSQL {
  282. if oriCol.Length < col.Length {
  283. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  284. tbNameWithSchema, col.Name, oriCol.Length, col.Length)
  285. _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
  286. }
  287. }
  288. } else {
  289. if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
  290. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
  291. tbNameWithSchema, col.Name, curType, expectedType)
  292. }
  293. }
  294. } else if expectedType == schemas.Varchar {
  295. if engine.dialect.URI().DBType == schemas.MYSQL {
  296. if oriCol.Length < col.Length {
  297. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  298. tbNameWithSchema, col.Name, oriCol.Length, col.Length)
  299. _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
  300. }
  301. }
  302. }
  303. if col.Default != oriCol.Default {
  304. switch {
  305. case col.IsAutoIncrement: // For autoincrement column, don't check default
  306. case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
  307. ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
  308. (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
  309. default:
  310. engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
  311. tbName, col.Name, oriCol.Default, col.Default)
  312. }
  313. }
  314. if col.Nullable != oriCol.Nullable {
  315. engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
  316. tbName, col.Name, oriCol.Nullable, col.Nullable)
  317. }
  318. if err != nil {
  319. return err
  320. }
  321. }
  322. var foundIndexNames = make(map[string]bool)
  323. var addedNames = make(map[string]*schemas.Index)
  324. for name, index := range table.Indexes {
  325. var oriIndex *schemas.Index
  326. for name2, index2 := range oriTable.Indexes {
  327. if index.Equal(index2) {
  328. oriIndex = index2
  329. foundIndexNames[name2] = true
  330. break
  331. }
  332. }
  333. if oriIndex != nil {
  334. if oriIndex.Type != index.Type {
  335. sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
  336. _, err = session.exec(sql)
  337. if err != nil {
  338. return err
  339. }
  340. oriIndex = nil
  341. }
  342. }
  343. if oriIndex == nil {
  344. addedNames[name] = index
  345. }
  346. }
  347. for name2, index2 := range oriTable.Indexes {
  348. if _, ok := foundIndexNames[name2]; !ok {
  349. sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
  350. _, err = session.exec(sql)
  351. if err != nil {
  352. return err
  353. }
  354. }
  355. }
  356. for name, index := range addedNames {
  357. if index.Type == schemas.UniqueType {
  358. session.statement.RefTable = table
  359. session.statement.SetTableName(tbNameWithSchema)
  360. err = session.addUnique(tbNameWithSchema, name)
  361. } else if index.Type == schemas.IndexType {
  362. session.statement.RefTable = table
  363. session.statement.SetTableName(tbNameWithSchema)
  364. err = session.addIndex(tbNameWithSchema, name)
  365. }
  366. if err != nil {
  367. return err
  368. }
  369. }
  370. // check all the columns which removed from struct fields but left on database tables.
  371. for _, colName := range oriTable.ColumnsSeq() {
  372. if table.GetColumn(colName) == nil {
  373. engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
  374. }
  375. }
  376. }
  377. return nil
  378. }
  379. // ImportFile SQL DDL file
  380. func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
  381. file, err := os.Open(ddlPath)
  382. if err != nil {
  383. return nil, err
  384. }
  385. defer file.Close()
  386. return session.Import(file)
  387. }
  388. // Import SQL DDL from io.Reader
  389. func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
  390. var results []sql.Result
  391. var lastError error
  392. scanner := bufio.NewScanner(r)
  393. var inSingleQuote bool
  394. semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  395. if atEOF && len(data) == 0 {
  396. return 0, nil, nil
  397. }
  398. for i, b := range data {
  399. if b == '\'' {
  400. inSingleQuote = !inSingleQuote
  401. }
  402. if !inSingleQuote && b == ';' {
  403. return i + 1, data[0:i], nil
  404. }
  405. }
  406. // If we're at EOF, we have a final, non-terminated line. Return it.
  407. if atEOF {
  408. return len(data), data, nil
  409. }
  410. // Request more data.
  411. return 0, nil, nil
  412. }
  413. scanner.Split(semiColSpliter)
  414. for scanner.Scan() {
  415. query := strings.Trim(scanner.Text(), " \t\n\r")
  416. if len(query) > 0 {
  417. result, err := session.Exec(query)
  418. results = append(results, result)
  419. if err != nil {
  420. return nil, err
  421. }
  422. }
  423. }
  424. return results, lastError
  425. }