Gitea source code

providers.go 6.6KB

(line numbers 1–226; gutter collapsed during extraction)
  1. // Copyright 2021 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. package oauth2
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "html"
  9. "html/template"
  10. "net/url"
  11. "slices"
  12. "sort"
  13. "code.gitea.io/gitea/models/auth"
  14. "code.gitea.io/gitea/models/db"
  15. "code.gitea.io/gitea/modules/log"
  16. "code.gitea.io/gitea/modules/optional"
  17. "code.gitea.io/gitea/modules/setting"
  18. "github.com/markbates/goth"
  19. )
  20. // Provider is an interface for describing a single OAuth2 provider
  21. type Provider interface {
  22. Name() string
  23. DisplayName() string
  24. IconHTML(size int) template.HTML
  25. CustomURLSettings() *CustomURLSettings
  26. SupportSSHPublicKey() bool
  27. }
  28. // GothProviderCreator provides a function to create a goth.Provider
  29. type GothProviderCreator interface {
  30. CreateGothProvider(providerName, callbackURL string, source *Source) (goth.Provider, error)
  31. }
  32. // GothProvider is an interface for describing a single OAuth2 provider
  33. type GothProvider interface {
  34. Provider
  35. GothProviderCreator
  36. }
  37. // AuthSourceProvider provides a provider for an AuthSource. Multiple auth sources could use the same registered GothProvider
  38. // So each auth source should have its own DisplayName and IconHTML for display.
  39. // The Name is the GothProvider's name, to help to find the GothProvider to sign in.
  40. // The DisplayName is the auth source config's name, site admin set it on the admin page, the IconURL can also be set there.
  41. type AuthSourceProvider struct {
  42. GothProvider
  43. sourceName, iconURL string
  44. }
  45. func (p *AuthSourceProvider) Name() string {
  46. return p.GothProvider.Name()
  47. }
  48. func (p *AuthSourceProvider) DisplayName() string {
  49. return p.sourceName
  50. }
  51. func (p *AuthSourceProvider) IconHTML(size int) template.HTML {
  52. if p.iconURL != "" {
  53. img := fmt.Sprintf(`<img class="tw-object-contain tw-mr-2" width="%d" height="%d" src="%s" alt="%s">`,
  54. size,
  55. size,
  56. html.EscapeString(p.iconURL), html.EscapeString(p.DisplayName()),
  57. )
  58. return template.HTML(img)
  59. }
  60. return p.GothProvider.IconHTML(size)
  61. }
  62. // Providers contains the map of registered OAuth2 providers in Gitea (based on goth)
  63. // key is used to map the OAuth2Provider with the goth provider type (also in AuthSource.OAuth2Config.Provider)
  64. // value is used to store display data
  65. var gothProviders = map[string]GothProvider{}
  // isAzureProvider reports whether name is one of the Azure AD provider type
// names (case-sensitive exact match).
func isAzureProvider(name string) bool {
	switch name {
	case "azuread", "microsoftonline", "azureadv2":
		return true
	default:
		return false
	}
}
  69. // RegisterGothProvider registers a GothProvider
  70. func RegisterGothProvider(provider GothProvider) {
  71. if _, has := gothProviders[provider.Name()]; has {
  72. log.Fatal("Duplicate oauth2provider type provided: %s", provider.Name())
  73. }
  74. gothProviders[provider.Name()] = provider
  75. }
  76. // getExistingAzureADAuthSources returns a list of Azure AD provider names that are already configured
  77. func getExistingAzureADAuthSources(ctx context.Context) ([]string, error) {
  78. authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
  79. LoginType: auth.OAuth2,
  80. })
  81. if err != nil {
  82. return nil, err
  83. }
  84. var existingAzureProviders []string
  85. for _, source := range authSources {
  86. if oauth2Cfg, ok := source.Cfg.(*Source); ok {
  87. if isAzureProvider(oauth2Cfg.Provider) {
  88. existingAzureProviders = append(existingAzureProviders, oauth2Cfg.Provider)
  89. }
  90. }
  91. }
  92. return existingAzureProviders, nil
  93. }
  94. // GetSupportedOAuth2Providers returns the list of supported OAuth2 providers with context for filtering
  95. // key is used as technical name (like in the callbackURL)
  96. // values to display
  97. // Note: Azure AD providers (azuread, microsoftonline, azureadv2) are filtered out
  98. // unless they already exist in the system to encourage use of OpenID Connect
  99. func GetSupportedOAuth2Providers(ctx context.Context) []Provider {
  100. providers := make([]Provider, 0, len(gothProviders))
  101. existingAzureSources, err := getExistingAzureADAuthSources(ctx)
  102. if err != nil {
  103. log.Error("Failed to get existing OAuth2 auth sources: %v", err)
  104. }
  105. for _, provider := range gothProviders {
  106. if isAzureProvider(provider.Name()) && !slices.Contains(existingAzureSources, provider.Name()) {
  107. continue
  108. }
  109. providers = append(providers, provider)
  110. }
  111. sort.Slice(providers, func(i, j int) bool {
  112. return providers[i].Name() < providers[j].Name()
  113. })
  114. return providers
  115. }
  116. func CreateProviderFromSource(source *auth.Source) (Provider, error) {
  117. oauth2Cfg, ok := source.Cfg.(*Source)
  118. if !ok {
  119. return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg)
  120. }
  121. gothProv := gothProviders[oauth2Cfg.Provider]
  122. return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil
  123. }
  124. // GetOAuth2Providers returns the list of configured OAuth2 providers
  125. func GetOAuth2Providers(ctx context.Context, isActive optional.Option[bool]) ([]Provider, error) {
  126. authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
  127. IsActive: isActive,
  128. LoginType: auth.OAuth2,
  129. })
  130. if err != nil {
  131. return nil, err
  132. }
  133. providers := make([]Provider, 0, len(authSources))
  134. for _, source := range authSources {
  135. provider, err := CreateProviderFromSource(source)
  136. if err != nil {
  137. return nil, err
  138. }
  139. providers = append(providers, provider)
  140. }
  141. sort.Slice(providers, func(i, j int) bool {
  142. return providers[i].Name() < providers[j].Name()
  143. })
  144. return providers, nil
  145. }
  146. // RegisterProviderWithGothic register a OAuth2 provider in goth lib
  147. func RegisterProviderWithGothic(providerName string, source *Source) error {
  148. provider, err := createProvider(providerName, source)
  149. if err == nil && provider != nil {
  150. gothRWMutex.Lock()
  151. defer gothRWMutex.Unlock()
  152. goth.UseProviders(provider)
  153. }
  154. return err
  155. }
  156. // RemoveProviderFromGothic removes the given OAuth2 provider from the goth lib
  157. func RemoveProviderFromGothic(providerName string) {
  158. gothRWMutex.Lock()
  159. defer gothRWMutex.Unlock()
  160. delete(goth.GetProviders(), providerName)
  161. }
  162. // ClearProviders clears all OAuth2 providers from the goth lib
  163. func ClearProviders() {
  164. gothRWMutex.Lock()
  165. defer gothRWMutex.Unlock()
  166. goth.ClearProviders()
  167. }
  168. var ErrAuthSourceNotActivated = errors.New("auth source is not activated")
  169. // used to create different types of goth providers
  170. func createProvider(providerName string, source *Source) (goth.Provider, error) {
  171. callbackURL := setting.AppURL + "user/oauth2/" + url.PathEscape(providerName) + "/callback"
  172. var provider goth.Provider
  173. var err error
  174. p, ok := gothProviders[source.Provider]
  175. if !ok {
  176. return nil, ErrAuthSourceNotActivated
  177. }
  178. provider, err = p.CreateGothProvider(providerName, callbackURL, source)
  179. if err != nil {
  180. return provider, err
  181. }
  182. // always set the name if provider is created so we can support multiple setups of 1 provider
  183. if provider != nil {
  184. provider.SetName(providerName)
  185. }
  186. return provider, err
  187. }