gitea源码

init.go 1.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. // Copyright 2021 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. package oauth2
  4. import (
  5. "context"
  6. "encoding/gob"
  7. "net/http"
  8. "sync"
  9. "code.gitea.io/gitea/models/auth"
  10. "code.gitea.io/gitea/models/db"
  11. "code.gitea.io/gitea/modules/log"
  12. "code.gitea.io/gitea/modules/optional"
  13. "code.gitea.io/gitea/modules/setting"
  14. "github.com/google/uuid"
  15. "github.com/gorilla/sessions"
  16. "github.com/markbates/goth/gothic"
  17. )
  // gothRWMutex is the package-level lock taken by Init while it installs the
  // gothic session store and callbacks.
  var gothRWMutex sync.RWMutex

  const (
  	// UsersStoreKey is the key for the store.
  	UsersStoreKey = "gitea-oauth2-sessions"

  	// ProviderHeaderKey is the HTTP header key; the provider-name lookup
  	// installed by Init reads the provider from this request header.
  	ProviderHeaderKey = "gitea-oauth2-provider"
  )
  23. // Init initializes the oauth source
  24. func Init(ctx context.Context) error {
  25. // Lock our mutex
  26. gothRWMutex.Lock()
  27. gob.Register(&sessions.Session{})
  28. gothic.Store = &SessionsStore{
  29. maxLength: int64(setting.OAuth2.MaxTokenLength),
  30. }
  31. gothic.SetState = func(req *http.Request) string {
  32. return uuid.New().String()
  33. }
  34. gothic.GetProviderName = func(req *http.Request) (string, error) {
  35. return req.Header.Get(ProviderHeaderKey), nil
  36. }
  37. // Unlock our mutex
  38. gothRWMutex.Unlock()
  39. return initOAuth2Sources(ctx)
  40. }
  41. // ResetOAuth2 clears existing OAuth2 providers and loads them from DB
  42. func ResetOAuth2(ctx context.Context) error {
  43. ClearProviders()
  44. return initOAuth2Sources(ctx)
  45. }
  46. // initOAuth2Sources is used to load and register all active OAuth2 providers
  47. func initOAuth2Sources(ctx context.Context) error {
  48. authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
  49. IsActive: optional.Some(true),
  50. LoginType: auth.OAuth2,
  51. })
  52. if err != nil {
  53. return err
  54. }
  55. for _, source := range authSources {
  56. oauth2Source, ok := source.Cfg.(*Source)
  57. if !ok {
  58. continue
  59. }
  60. err := oauth2Source.RegisterSource()
  61. if err != nil {
  62. log.Critical("Unable to register source: %s due to Error: %v.", source.Name, err)
  63. }
  64. }
  65. return nil
  66. }