gitea源码

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. // Copyright 2024 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. package unittest
  4. import (
  5. "database/sql"
  6. "encoding/hex"
  7. "fmt"
  8. "os"
  9. "path/filepath"
  10. "slices"
  11. "strings"
  12. "code.gitea.io/gitea/models/db"
  13. "gopkg.in/yaml.v3"
  14. "xorm.io/xorm"
  15. "xorm.io/xorm/schemas"
  16. )
  17. type FixtureItem struct {
  18. fileFullPath string
  19. tableName string
  20. tableNameQuoted string
  21. sqlInserts []string
  22. sqlInsertArgs [][]any
  23. mssqlHasIdentityColumn bool
  24. }
  25. type fixturesLoaderInternal struct {
  26. xormEngine *xorm.Engine
  27. xormTableNames map[string]bool
  28. db *sql.DB
  29. dbType schemas.DBType
  30. fixtures map[string]*FixtureItem
  31. quoteObject func(string) string
  32. paramPlaceholder func(idx int) string
  33. }
  34. func (f *fixturesLoaderInternal) mssqlTableHasIdentityColumn(db *sql.DB, tableName string) (bool, error) {
  35. row := db.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
  36. var count int
  37. if err := row.Scan(&count); err != nil {
  38. return false, err
  39. }
  40. return count > 0, nil
  41. }
  42. func (f *fixturesLoaderInternal) preprocessFixtureRow(row []map[string]any) (err error) {
  43. for _, m := range row {
  44. for k, v := range m {
  45. if s, ok := v.(string); ok {
  46. if strings.HasPrefix(s, "0x") {
  47. if m[k], err = hex.DecodeString(s[2:]); err != nil {
  48. return err
  49. }
  50. }
  51. }
  52. }
  53. }
  54. return nil
  55. }
  56. func (f *fixturesLoaderInternal) prepareFixtureItem(fixture *FixtureItem) (err error) {
  57. fixture.tableNameQuoted = f.quoteObject(fixture.tableName)
  58. if f.dbType == schemas.MSSQL {
  59. fixture.mssqlHasIdentityColumn, err = f.mssqlTableHasIdentityColumn(f.db, fixture.tableName)
  60. if err != nil {
  61. return err
  62. }
  63. }
  64. data, err := os.ReadFile(fixture.fileFullPath)
  65. if err != nil {
  66. return fmt.Errorf("failed to read file %q: %w", fixture.fileFullPath, err)
  67. }
  68. var rows []map[string]any
  69. if err = yaml.Unmarshal(data, &rows); err != nil {
  70. return fmt.Errorf("failed to unmarshal yaml data from %q: %w", fixture.fileFullPath, err)
  71. }
  72. if err = f.preprocessFixtureRow(rows); err != nil {
  73. return fmt.Errorf("failed to preprocess fixture rows from %q: %w", fixture.fileFullPath, err)
  74. }
  75. var sqlBuf []byte
  76. var sqlArguments []any
  77. for _, row := range rows {
  78. sqlBuf = append(sqlBuf, fmt.Sprintf("INSERT INTO %s (", fixture.tableNameQuoted)...)
  79. for k, v := range row {
  80. sqlBuf = append(sqlBuf, f.quoteObject(k)...)
  81. sqlBuf = append(sqlBuf, ","...)
  82. sqlArguments = append(sqlArguments, v)
  83. }
  84. sqlBuf = sqlBuf[:len(sqlBuf)-1]
  85. sqlBuf = append(sqlBuf, ") VALUES ("...)
  86. paramIdx := 1
  87. for range row {
  88. sqlBuf = append(sqlBuf, f.paramPlaceholder(paramIdx)...)
  89. sqlBuf = append(sqlBuf, ',')
  90. paramIdx++
  91. }
  92. sqlBuf[len(sqlBuf)-1] = ')'
  93. fixture.sqlInserts = append(fixture.sqlInserts, string(sqlBuf))
  94. fixture.sqlInsertArgs = append(fixture.sqlInsertArgs, slices.Clone(sqlArguments))
  95. sqlBuf = sqlBuf[:0]
  96. sqlArguments = sqlArguments[:0]
  97. }
  98. return nil
  99. }
  100. func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, fixture *FixtureItem) (err error) {
  101. if fixture.tableNameQuoted == "" {
  102. if err = f.prepareFixtureItem(fixture); err != nil {
  103. return err
  104. }
  105. }
  106. _, err = tx.Exec("DELETE FROM " + fixture.tableNameQuoted) // sqlite3 doesn't support truncate
  107. if err != nil {
  108. return err
  109. }
  110. if fixture.mssqlHasIdentityColumn {
  111. _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", fixture.tableNameQuoted))
  112. if err != nil {
  113. return err
  114. }
  115. defer func() { _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", fixture.tableNameQuoted)) }()
  116. }
  117. for i := range fixture.sqlInserts {
  118. _, err = tx.Exec(fixture.sqlInserts[i], fixture.sqlInsertArgs[i]...)
  119. }
  120. if err != nil {
  121. return err
  122. }
  123. return nil
  124. }
  125. func (f *fixturesLoaderInternal) Load() error {
  126. tx, err := f.db.Begin()
  127. if err != nil {
  128. return err
  129. }
  130. defer func() { _ = tx.Rollback() }()
  131. for _, fixture := range f.fixtures {
  132. if !f.xormTableNames[fixture.tableName] {
  133. continue
  134. }
  135. if err := f.loadFixtures(tx, fixture); err != nil {
  136. return fmt.Errorf("failed to load fixtures from %s: %w", fixture.fileFullPath, err)
  137. }
  138. }
  139. if err = tx.Commit(); err != nil {
  140. return err
  141. }
  142. for xormTableName := range f.xormTableNames {
  143. if f.fixtures[xormTableName] == nil {
  144. _, _ = f.xormEngine.Exec("DELETE FROM `" + xormTableName + "`")
  145. }
  146. }
  147. return nil
  148. }
  149. func FixturesFileFullPaths(dir string, files []string) (map[string]*FixtureItem, error) {
  150. if files != nil && len(files) == 0 {
  151. return nil, nil // load nothing
  152. }
  153. files = slices.Clone(files)
  154. if len(files) == 0 {
  155. entries, err := os.ReadDir(dir)
  156. if err != nil {
  157. return nil, err
  158. }
  159. for _, e := range entries {
  160. files = append(files, e.Name())
  161. }
  162. }
  163. fixtureItems := map[string]*FixtureItem{}
  164. for _, file := range files {
  165. fileFillPath := file
  166. if !filepath.IsAbs(fileFillPath) {
  167. fileFillPath = filepath.Join(dir, file)
  168. }
  169. tableName, _, _ := strings.Cut(filepath.Base(file), ".")
  170. fixtureItems[tableName] = &FixtureItem{fileFullPath: fileFillPath, tableName: tableName}
  171. }
  172. return fixtureItems, nil
  173. }
  174. func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, error) {
  175. fixtureItems, err := FixturesFileFullPaths(opts.Dir, opts.Files)
  176. if err != nil {
  177. return nil, fmt.Errorf("failed to get fixtures files: %w", err)
  178. }
  179. f := &fixturesLoaderInternal{xormEngine: x, db: x.DB().DB, dbType: x.Dialect().URI().DBType, fixtures: fixtureItems}
  180. switch f.dbType {
  181. case schemas.SQLITE:
  182. f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
  183. f.paramPlaceholder = func(idx int) string { return "?" }
  184. case schemas.POSTGRES:
  185. f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
  186. f.paramPlaceholder = func(idx int) string { return fmt.Sprintf(`$%d`, idx) }
  187. case schemas.MYSQL:
  188. f.quoteObject = func(s string) string { return fmt.Sprintf("`%s`", s) }
  189. f.paramPlaceholder = func(idx int) string { return "?" }
  190. case schemas.MSSQL:
  191. f.quoteObject = func(s string) string { return fmt.Sprintf("[%s]", s) }
  192. f.paramPlaceholder = func(idx int) string { return "?" }
  193. }
  194. xormBeans, _ := db.NamesToBean()
  195. f.xormTableNames = map[string]bool{}
  196. for _, bean := range xormBeans {
  197. f.xormTableNames[x.TableName(bean)] = true
  198. }
  199. return f, nil
  200. }