idp.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package idp
  15. import (
  16. "context"
  17. "encoding/base64"
  18. "encoding/xml"
  19. "io/ioutil"
  20. "net/http"
  21. "strings"
  22. "yunion.io/x/log"
  23. "yunion.io/x/pkg/appctx"
  24. "yunion.io/x/pkg/errors"
  25. "yunion.io/x/pkg/util/httputils"
  26. "yunion.io/x/pkg/util/samlutils"
  27. "yunion.io/x/onecloud/pkg/appsrv"
  28. "yunion.io/x/onecloud/pkg/httperrors"
  29. "yunion.io/x/onecloud/pkg/i18n"
  30. )
  // IDP_ID_KEY is the placeholder segment used in the IdP route templates;
// handlers read the concrete value back from the path parameters under this
// key, and URL builders substitute it with a real idp id.
const IDP_ID_KEY = "<idp_id>"

// langTemplateKey is the i18n table key under which the per-language HTML
// page template (see SetHtmlTemplate) is stored.
const langTemplateKey = "lang_template_key"
  35. type OnSpInitiatedLogin func(ctx context.Context, idpId string, sp *SSAMLServiceProvider) (samlutils.SSAMLSpInitiatedLoginData, error)
  36. type OnIdpInitiatedLogin func(ctx context.Context, sp *SSAMLServiceProvider, IdpId, redirectUrl string) (samlutils.SSAMLIdpInitiatedLoginData, error)
  37. type OnLogout func(ctx context.Context, idpId string) string
  38. type SSAMLIdpInstance struct {
  39. saml *samlutils.SSAMLInstance
  40. metadataPath string
  41. redirectLoginPath string
  42. redirectLogoutPath string
  43. idpInitiatedSSOPath string
  44. serviceProviders []*SSAMLServiceProvider
  45. onSpInitiatedLogin OnSpInitiatedLogin
  46. onIdpInitiatedLogin OnIdpInitiatedLogin
  47. onLogout OnLogout
  48. htmlTemplate i18n.Table
  49. }
  50. func NewIdpInstance(saml *samlutils.SSAMLInstance, spLoginFunc OnSpInitiatedLogin, idpLoginFunc OnIdpInitiatedLogin, logoutFunc OnLogout) *SSAMLIdpInstance {
  51. return &SSAMLIdpInstance{
  52. saml: saml,
  53. onSpInitiatedLogin: spLoginFunc,
  54. onIdpInitiatedLogin: idpLoginFunc,
  55. onLogout: logoutFunc,
  56. htmlTemplate: i18n.Table{},
  57. }
  58. }
  59. func (idp *SSAMLIdpInstance) AddHandlers(app *appsrv.Application, prefix string, middleware appsrv.TMiddleware) {
  60. idp.metadataPath = httputils.JoinPath(prefix, "metadata/"+IDP_ID_KEY)
  61. idp.redirectLoginPath = httputils.JoinPath(prefix, "redirect/login/"+IDP_ID_KEY)
  62. idp.redirectLogoutPath = httputils.JoinPath(prefix, "redirect/logout/"+IDP_ID_KEY)
  63. idp.idpInitiatedSSOPath = httputils.JoinPath(prefix, "sso")
  64. app.AddHandler("GET", idp.metadataPath, idp.metadataHandler)
  65. handler := idp.redirectLoginHandler
  66. if middleware != nil {
  67. handler = middleware(handler)
  68. }
  69. app.AddHandler("POST", idp.redirectLoginPath, handler)
  70. app.AddHandler("GET", idp.redirectLoginPath, handler)
  71. handler = idp.redirectLogoutHandler
  72. if middleware != nil {
  73. handler = middleware(handler)
  74. }
  75. app.AddHandler("GET", idp.redirectLogoutPath, handler)
  76. handler = idp.idpInitiatedSSOHandler
  77. if middleware != nil {
  78. handler = middleware(handler)
  79. }
  80. app.AddHandler("GET", idp.idpInitiatedSSOPath, handler)
  81. log.Infof("IDP metadata: %s", idp.getMetadataUrl(IDP_ID_KEY))
  82. log.Infof("IDP redirect login: %s", idp.getRedirectLoginUrl(IDP_ID_KEY))
  83. log.Infof("IDP redirect logout: %s", idp.getRedirectLogoutUrl(IDP_ID_KEY))
  84. log.Infof("IDP initated SSO: %s", idp.getIdpInitiatedSSOUrl())
  85. }
  86. func (idp *SSAMLIdpInstance) SetHtmlTemplate(entry i18n.TableEntry) error {
  87. for _, tmp := range entry {
  88. if strings.Index(tmp, samlutils.HTML_SAML_FORM_TOKEN) < 0 {
  89. return errors.Wrapf(httperrors.ErrInvalidFormat, "no %s found", samlutils.HTML_SAML_FORM_TOKEN)
  90. }
  91. }
  92. idp.htmlTemplate.Set(langTemplateKey, entry)
  93. return nil
  94. }
  95. func (idp *SSAMLIdpInstance) AddSPMetadataFile(filename string) error {
  96. metadata, err := ioutil.ReadFile(filename)
  97. if err != nil {
  98. return errors.Wrap(err, "ioutil.ReadFile")
  99. }
  100. return idp.AddSPMetadata(metadata)
  101. }
  102. func (idp *SSAMLIdpInstance) AddSPMetadata(metadata []byte) error {
  103. ed, err := samlutils.ParseMetadata(metadata)
  104. if err != nil {
  105. return errors.Wrap(err, "samlutils.ParseMetadata")
  106. }
  107. sp := &SSAMLServiceProvider{desc: ed}
  108. err = sp.IsValid()
  109. if err != nil {
  110. return errors.Wrap(err, "NewSAMLServiceProvider")
  111. }
  112. log.Debugf("Register SP metadata: %s", sp.GetEntityId())
  113. idp.serviceProviders = append(idp.serviceProviders, sp)
  114. return nil
  115. }
  116. func (idp *SSAMLIdpInstance) getMetadataUrl(idpId string) string {
  117. return strings.Replace(httputils.JoinPath(idp.saml.GetEntityId(), idp.metadataPath), IDP_ID_KEY, idpId, 1)
  118. }
  119. func (idp *SSAMLIdpInstance) getRedirectLoginUrl(idpId string) string {
  120. return strings.Replace(httputils.JoinPath(idp.saml.GetEntityId(), idp.redirectLoginPath), IDP_ID_KEY, idpId, 1)
  121. }
  122. func (idp *SSAMLIdpInstance) getRedirectLogoutUrl(idpId string) string {
  123. return strings.Replace(httputils.JoinPath(idp.saml.GetEntityId(), idp.redirectLogoutPath), IDP_ID_KEY, idpId, 1)
  124. }
  125. func (idp *SSAMLIdpInstance) getIdpInitiatedSSOUrl() string {
  126. return httputils.JoinPath(idp.saml.GetEntityId(), idp.idpInitiatedSSOPath)
  127. }
  128. func (idp *SSAMLIdpInstance) metadataHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  129. params := appctx.AppContextParams(ctx)
  130. idpId := params[IDP_ID_KEY]
  131. desc := idp.GetMetadata(idpId)
  132. appsrv.SendXmlWithIndent(w, nil, desc, true)
  133. }
  134. func (idp *SSAMLIdpInstance) redirectLoginHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  135. params, query, _ := appsrv.FetchEnv(ctx, w, r)
  136. idpId := params[IDP_ID_KEY]
  137. input := samlutils.SIdpRedirectLoginInput{}
  138. err := query.Unmarshal(&input)
  139. if err != nil {
  140. httperrors.InputParameterError(ctx, w, "query.Unmarshal error %s", err)
  141. return
  142. }
  143. log.Debugf("recv input %s", input)
  144. respHtml, err := idp.processLoginRequest(ctx, idpId, input)
  145. if err != nil {
  146. httperrors.InputParameterError(ctx, w, "parse parameter error %s", err)
  147. return
  148. }
  149. appsrv.SendHTML(w, respHtml)
  150. }
  151. func (idp *SSAMLIdpInstance) redirectLogoutHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  152. params := appctx.AppContextParams(ctx)
  153. idpId := params[IDP_ID_KEY]
  154. log.Debugf("logout: %s", r.Header)
  155. html := idp.onLogout(ctx, idpId)
  156. appsrv.SendHTML(w, html)
  157. }
  158. func (idp *SSAMLIdpInstance) idpInitiatedSSOHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  159. _, query, _ := appsrv.FetchEnv(ctx, w, r)
  160. input := samlutils.SIdpInitiatedLoginInput{}
  161. err := query.Unmarshal(&input)
  162. if err != nil {
  163. httperrors.InputParameterError(ctx, w, "unmarshal input fail %s", err)
  164. return
  165. }
  166. respHtml, err := idp.processIdpInitiatedLogin(ctx, input)
  167. if err != nil {
  168. httperrors.GeneralServerError(ctx, w, err)
  169. return
  170. }
  171. appsrv.SendHTML(w, respHtml)
  172. }
  173. func (idp *SSAMLIdpInstance) GetMetadata(idpId string) samlutils.EntityDescriptor {
  174. input := samlutils.SSAMLIdpMetadataInput{
  175. EntityId: idp.saml.GetEntityId(),
  176. CertString: idp.saml.GetCertString(),
  177. RedirectLoginUrl: idp.getRedirectLoginUrl(idpId),
  178. RedirectLogoutUrl: idp.getRedirectLogoutUrl(idpId),
  179. }
  180. return samlutils.NewIdpMetadata(input)
  181. }
  182. func (idp *SSAMLIdpInstance) processLoginRequest(ctx context.Context, idpId string, input samlutils.SIdpRedirectLoginInput) (string, error) {
  183. plainText, err := samlutils.SAMLDecode(input.SAMLRequest)
  184. if err != nil {
  185. return "", errors.Wrap(err, "samlutils.SAMLDecode")
  186. }
  187. log.Debugf("AuthnRequest: %s", string(plainText))
  188. authReq := samlutils.AuthnRequest{}
  189. err = xml.Unmarshal(plainText, &authReq)
  190. if err != nil {
  191. return "", errors.Wrap(err, "xml.Unmarshal")
  192. }
  193. sp := idp.getServiceProvider(authReq.Issuer.Issuer)
  194. if sp == nil {
  195. return "", errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", authReq.Issuer.Issuer)
  196. }
  197. if len(authReq.Destination) > 0 && authReq.Destination != idp.getRedirectLoginUrl(idpId) {
  198. return "", errors.Wrapf(httperrors.ErrInputParameter, "Destination not match: get %s want %s", authReq.Destination, idp.getRedirectLoginUrl(idpId))
  199. }
  200. if len(authReq.AssertionConsumerServiceURL) > 0 && authReq.AssertionConsumerServiceURL != sp.GetPostAssertionConsumerServiceUrl() {
  201. return "", errors.Wrapf(httperrors.ErrInputParameter, "AssertionConsumerServiceURL not match: get %s want %s", authReq.AssertionConsumerServiceURL, sp.GetPostAssertionConsumerServiceUrl())
  202. }
  203. sp.Username = input.Username
  204. resp, err := idp.getLoginResponse(ctx, authReq, idpId, sp)
  205. if err != nil {
  206. return "", errors.Wrap(err, "getLoginResponse")
  207. }
  208. form, err := idp.samlResponse2Form(ctx, sp.GetPostAssertionConsumerServiceUrl(), resp, input.RelayState)
  209. if err != nil {
  210. return "", errors.Wrap(err, "samlResponse2Form")
  211. }
  212. return form, nil
  213. }
  214. func (idp *SSAMLIdpInstance) samlResponse2Form(ctx context.Context, url string, resp *samlutils.Response, state string) (string, error) {
  215. respXml, err := xml.Marshal(resp)
  216. if err != nil {
  217. return "", errors.Wrap(err, "xml.Marshal")
  218. }
  219. signed, err := idp.saml.SignXML(string(respXml))
  220. if err != nil {
  221. return "", errors.Wrap(err, "saml.SignXML")
  222. }
  223. log.Debugf("ResponseXML: %s", signed)
  224. samlResp := base64.StdEncoding.EncodeToString([]byte(signed))
  225. form := samlutils.SAMLForm(url, map[string]string{
  226. "SAMLResponse": samlResp,
  227. "RelayState": state,
  228. })
  229. template := samlutils.DEFAULT_HTML_TEMPLATE
  230. _temp := idp.htmlTemplate.Lookup(ctx, langTemplateKey)
  231. if _temp != langTemplateKey {
  232. template = _temp
  233. }
  234. form = strings.Replace(template, samlutils.HTML_SAML_FORM_TOKEN, form, 1)
  235. return form, nil
  236. }
  237. func (idp *SSAMLIdpInstance) getServiceProvider(eId string) *SSAMLServiceProvider {
  238. for _, sp := range idp.serviceProviders {
  239. if sp.GetEntityId() == eId {
  240. return sp
  241. }
  242. }
  243. return nil
  244. }
  245. func (idp *SSAMLIdpInstance) getLoginResponse(ctx context.Context, req samlutils.AuthnRequest, idpId string, sp *SSAMLServiceProvider) (*samlutils.Response, error) {
  246. data, err := idp.onSpInitiatedLogin(ctx, idpId, sp)
  247. if err != nil {
  248. return nil, errors.Wrap(err, "idp.onSpInitiatedLogin")
  249. }
  250. input := samlutils.SSAMLResponseInput{
  251. IssuerCertString: idp.saml.GetCertString(),
  252. IssuerEntityId: idp.saml.GetEntityId(),
  253. RequestID: req.ID,
  254. RequestEntityId: req.Issuer.Issuer,
  255. AssertionConsumerServiceURL: sp.GetPostAssertionConsumerServiceUrl(),
  256. SSAMLSpInitiatedLoginData: data,
  257. }
  258. resp := samlutils.NewResponse(input)
  259. return &resp, nil
  260. }
  261. func (idp *SSAMLIdpInstance) processIdpInitiatedLogin(ctx context.Context, input samlutils.SIdpInitiatedLoginInput) (string, error) {
  262. sp := idp.getServiceProvider(input.EntityID)
  263. if sp == nil {
  264. return "", errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", input.EntityID)
  265. }
  266. data, err := idp.onIdpInitiatedLogin(ctx, sp, input.IdpId, input.RedirectUrl)
  267. if err != nil {
  268. return "", errors.Wrap(err, "idp.onIdpInitiatedLogin")
  269. }
  270. if len(data.Form) > 0 {
  271. return data.Form, nil
  272. }
  273. respInput := samlutils.SSAMLResponseInput{
  274. IssuerCertString: idp.saml.GetCertString(),
  275. IssuerEntityId: idp.saml.GetEntityId(),
  276. RequestID: "",
  277. RequestEntityId: sp.GetEntityId(),
  278. AssertionConsumerServiceURL: sp.GetPostAssertionConsumerServiceUrl(),
  279. SSAMLSpInitiatedLoginData: data.SSAMLSpInitiatedLoginData,
  280. }
  281. resp := samlutils.NewResponse(respInput)
  282. form, err := idp.samlResponse2Form(ctx, sp.GetPostAssertionConsumerServiceUrl(), &resp, data.RelayState)
  283. if err != nil {
  284. return "", errors.Wrap(err, "samlResponse2Form")
  285. }
  286. return form, nil
  287. }