middleware.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package handler
  15. import (
  16. "context"
  17. "encoding/base64"
  18. "net/http"
  19. "strings"
  20. "yunion.io/x/jsonutils"
  21. "yunion.io/x/log"
  22. "yunion.io/x/pkg/appctx"
  23. "yunion.io/x/pkg/errors"
  24. "yunion.io/x/onecloud/pkg/apigateway/clientman"
  25. "yunion.io/x/onecloud/pkg/apigateway/constants"
  26. "yunion.io/x/onecloud/pkg/httperrors"
  27. "yunion.io/x/onecloud/pkg/mcclient"
  28. )
  // Base64UrlEncode encodes data with the URL- and filename-safe base64
// alphabet ('-' and '_' in place of '+' and '/') and omits the trailing '='
// padding.
//
// base64.RawURLEncoding yields exactly the same text as StdEncoding followed
// by '+'->'-', '/'->'_' and stripping '=', but in a single pass without the
// intermediate string copies.
func Base64UrlEncode(data []byte) string {
	return base64.RawURLEncoding.EncodeToString(data)
}
  36. func Base64UrlDecode(str string) ([]byte, error) {
  37. if strings.ContainsAny(str, "+/") {
  38. return nil, errors.Wrap(httperrors.ErrInputParameter, "invalid base64url encoding")
  39. }
  40. str = strings.Replace(str, "-", "+", -1)
  41. str = strings.Replace(str, "_", "/", -1)
  42. for len(str)%4 != 0 {
  43. str += "="
  44. }
  45. return base64.StdEncoding.DecodeString(str)
  46. }
  47. func getAuthToken(r *http.Request) string {
  48. auth := r.Header.Get(constants.AUTH_HEADER)
  49. if len(auth) > 0 && auth[:len(constants.AUTH_PREFIX)] == constants.AUTH_PREFIX {
  50. return auth[len(constants.AUTH_PREFIX):]
  51. } else {
  52. return ""
  53. }
  54. }
  55. func getAuthCookie(r *http.Request) string {
  56. return getCookie(r, constants.YUNION_AUTH_COOKIE)
  57. }
  58. /*func setAuthHeader(w http.ResponseWriter, tid string) {
  59. w.Header().Set(constants.AUTH_HEADER, fmt.Sprintf("%s%s", constants.AUTH_PREFIX, tid))
  60. }*/
  61. func fetchAuthInfo(ctx context.Context, r *http.Request) (mcclient.TokenCredential, *clientman.SAuthToken, error) {
  62. var token mcclient.TokenCredential
  63. var authToken *clientman.SAuthToken
  64. // no more use Auth header
  65. // auth1 := getAuthToken(r)
  66. auth := getAuthToken(r)
  67. if len(auth) == 0 {
  68. authCookieStr := getAuthCookie(r)
  69. if len(authCookieStr) > 0 {
  70. authCookie, err := jsonutils.ParseString(authCookieStr)
  71. if err != nil {
  72. return nil, nil, errors.Wrap(httperrors.ErrInputParameter, "Auth cookie decode")
  73. }
  74. auth, err = authCookie.GetString("session")
  75. if err != nil {
  76. return nil, nil, errors.Wrap(httperrors.ErrInputParameter, "authCookie missing session field")
  77. }
  78. }
  79. }
  80. // if len(auth) > 0 && auth != auth1 { // hack!!! browser cache problem???
  81. // log.Errorf("XXXX Auth cookie and header mismatch!!! %s:%s", auth, auth1)
  82. // auth = auth1
  83. // }
  84. if len(auth) > 0 {
  85. var err error
  86. authToken, err = clientman.Decode(auth)
  87. if err != nil {
  88. log.Errorf("clientman.Decode auth token fail: %v", err)
  89. return nil, nil, errors.Wrap(httperrors.ErrInputParameter, "clientman.Decode auth token fail")
  90. }
  91. token, err = authToken.GetToken(ctx)
  92. if err != nil {
  93. log.Errorf("authToken.GetToken fail: %v", err)
  94. return nil, nil, errors.Wrap(httperrors.ErrInputParameter, "authToken.GetToken fail")
  95. }
  96. }
  97. if token == nil {
  98. return nil, nil, errors.Wrap(httperrors.ErrInvalidCredential, "No token in header")
  99. } else if !token.IsValid() {
  100. return nil, nil, errors.Wrap(httperrors.ErrInvalidCredential, "Token in header invalid")
  101. }
  102. return token, authToken, nil
  103. }
  104. func fetchAndSetAuthContext(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
  105. token, authToken, err := fetchAuthInfo(ctx, r)
  106. if err != nil {
  107. return ctx, errors.Wrap(err, "fetchAuthInfo")
  108. }
  109. // 启用双因子认证
  110. if !authToken.IsTotpVerified() {
  111. return ctx, errors.Wrap(httperrors.ErrInvalidCredential, "TOTP authentication failed")
  112. }
  113. // no more send auth header, save auth info in cookie
  114. // setAuthHeader(w, authHeader)
  115. ctx = context.WithValue(ctx, appctx.APP_CONTEXT_KEY_AUTH_TOKEN, token)
  116. return ctx, nil
  117. }
  118. func FetchAuthToken(f func(context.Context, http.ResponseWriter, *http.Request)) func(context.Context, http.ResponseWriter, *http.Request) {
  119. return func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  120. ctx, err := fetchAndSetAuthContext(ctx, w, r)
  121. if err != nil {
  122. httperrors.InvalidCredentialError(ctx, w, "No token in header: %v", err)
  123. return
  124. }
  125. f(ctx, w, r)
  126. }
  127. }