gitea源码

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. // Copyright 2016 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. package unittest
  4. import (
  5. "context"
  6. "fmt"
  7. "math"
  8. "os"
  9. "strings"
  10. "code.gitea.io/gitea/models/db"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "xorm.io/builder"
  14. )
// Code in this file is mainly used by unittest.CheckConsistencyFor, which, for various reasons, is not part of the unit tests themselves.
// In the future, if CheckConsistencyFor can be decoupled into separate unit test code, this file could be moved into the unittest package too.
// NOTE(review): this file already declares package unittest, so the previous sentence appears to be stale.
  17. // NonexistentID an ID that will never exist
  18. const NonexistentID = int64(math.MaxInt64)
  19. type TestingT interface {
  20. require.TestingT
  21. assert.TestingT
  22. Context() context.Context
  23. }
  24. type testCond struct {
  25. query any
  26. args []any
  27. }
  28. type testOrderBy string
  29. // Cond create a condition with arguments for a test
  30. func Cond(query any, args ...any) any {
  31. return &testCond{query: query, args: args}
  32. }
  33. // OrderBy creates "ORDER BY" a test query
  34. func OrderBy(orderBy string) any {
  35. return testOrderBy(orderBy)
  36. }
  37. func whereOrderConditions(e db.Engine, conditions []any) db.Engine {
  38. orderBy := "id" // query must have the "ORDER BY", otherwise the result is not deterministic
  39. for _, condition := range conditions {
  40. switch cond := condition.(type) {
  41. case *testCond:
  42. e = e.Where(cond.query, cond.args...)
  43. case testOrderBy:
  44. orderBy = string(cond)
  45. default:
  46. e = e.Where(cond)
  47. }
  48. }
  49. return e.OrderBy(orderBy)
  50. }
  51. func getBeanIfExists(t TestingT, bean any, conditions ...any) (bool, error) {
  52. e := db.GetEngine(t.Context())
  53. return whereOrderConditions(e, conditions).Get(bean)
  54. }
  55. func GetBean[T any](t TestingT, bean T, conditions ...any) (ret T) {
  56. exists, err := getBeanIfExists(t, bean, conditions...)
  57. require.NoError(t, err)
  58. if exists {
  59. return bean
  60. }
  61. return ret
  62. }
  63. // AssertExistsAndLoadBean assert that a bean exists and load it from the test database
  64. func AssertExistsAndLoadBean[T any](t TestingT, bean T, conditions ...any) T {
  65. exists, err := getBeanIfExists(t, bean, conditions...)
  66. require.NoError(t, err)
  67. require.True(t, exists,
  68. "Expected to find %+v (of type %T, with conditions %+v), but did not",
  69. bean, bean, conditions)
  70. return bean
  71. }
  72. // AssertExistsAndLoadMap assert that a row exists and load it from the test database
  73. func AssertExistsAndLoadMap(t TestingT, table string, conditions ...any) map[string]string {
  74. e := db.GetEngine(t.Context()).Table(table)
  75. res, err := whereOrderConditions(e, conditions).Query()
  76. assert.NoError(t, err)
  77. assert.Len(t, res, 1,
  78. "Expected to find one row in %s (with conditions %+v), but found %d",
  79. table, conditions, len(res),
  80. )
  81. if len(res) == 1 {
  82. rec := map[string]string{}
  83. for k, v := range res[0] {
  84. rec[k] = string(v)
  85. }
  86. return rec
  87. }
  88. return nil
  89. }
  90. // GetCount get the count of a bean
  91. func GetCount(t TestingT, bean any, conditions ...any) int {
  92. e := db.GetEngine(t.Context())
  93. for _, condition := range conditions {
  94. switch cond := condition.(type) {
  95. case *testCond:
  96. e = e.Where(cond.query, cond.args...)
  97. default:
  98. e = e.Where(cond)
  99. }
  100. }
  101. count, err := e.Count(bean)
  102. assert.NoError(t, err)
  103. return int(count)
  104. }
  105. // AssertNotExistsBean assert that a bean does not exist in the test database
  106. func AssertNotExistsBean(t TestingT, bean any, conditions ...any) {
  107. exists, err := getBeanIfExists(t, bean, conditions...)
  108. assert.NoError(t, err)
  109. assert.False(t, exists)
  110. }
  111. // AssertCount assert the count of a bean
  112. func AssertCount(t TestingT, bean, expected any) bool {
  113. return assert.EqualValues(t, expected, GetCount(t, bean))
  114. }
  115. // AssertInt64InRange assert value is in range [low, high]
  116. func AssertInt64InRange(t assert.TestingT, low, high, value int64) {
  117. assert.True(t, value >= low && value <= high,
  118. "Expected value in range [%d, %d], found %d", low, high, value)
  119. }
  120. // GetCountByCond get the count of database entries matching bean
  121. func GetCountByCond(t TestingT, tableName string, cond builder.Cond) int64 {
  122. e := db.GetEngine(t.Context())
  123. count, err := e.Table(tableName).Where(cond).Count()
  124. assert.NoError(t, err)
  125. return count
  126. }
  127. // AssertCountByCond test the count of database entries matching bean
  128. func AssertCountByCond(t TestingT, tableName string, cond builder.Cond, expected int) bool {
  129. return assert.EqualValues(t, expected, GetCountByCond(t, tableName, cond),
  130. "Failed consistency test, the counted bean (of table %s) was %+v", tableName, cond)
  131. }
  132. // DumpQueryResult dumps the result of a query for debugging purpose
  133. func DumpQueryResult(t require.TestingT, sqlOrBean any, sqlArgs ...any) {
  134. x := GetXORMEngine()
  135. goDB := x.DB().DB
  136. sql, ok := sqlOrBean.(string)
  137. if !ok {
  138. sql = "SELECT * FROM " + x.TableName(sqlOrBean)
  139. } else if !strings.Contains(sql, " ") {
  140. sql = "SELECT * FROM " + sql
  141. }
  142. rows, err := goDB.Query(sql, sqlArgs...)
  143. require.NoError(t, err)
  144. defer rows.Close()
  145. columns, err := rows.Columns()
  146. require.NoError(t, err)
  147. _, _ = fmt.Fprintf(os.Stdout, "====== DumpQueryResult: %s ======\n", sql)
  148. idx := 0
  149. for rows.Next() {
  150. row := make([]any, len(columns))
  151. rowPointers := make([]any, len(columns))
  152. for i := range row {
  153. rowPointers[i] = &row[i]
  154. }
  155. require.NoError(t, rows.Scan(rowPointers...))
  156. _, _ = fmt.Fprintf(os.Stdout, "- # row[%d]\n", idx)
  157. for i, col := range columns {
  158. _, _ = fmt.Fprintf(os.Stdout, " %s: %v\n", col, row[i])
  159. }
  160. idx++
  161. }
  162. if idx == 0 {
  163. _, _ = fmt.Fprintf(os.Stdout, "(no result, columns: %s)\n", strings.Join(columns, ", "))
  164. }
  165. }