| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package idp
- import (
- "context"
- "encoding/base64"
- "encoding/xml"
- "io/ioutil"
- "net/http"
- "strings"
- "yunion.io/x/log"
- "yunion.io/x/pkg/appctx"
- "yunion.io/x/pkg/errors"
- "yunion.io/x/pkg/util/httputils"
- "yunion.io/x/pkg/util/samlutils"
- "yunion.io/x/onecloud/pkg/appsrv"
- "yunion.io/x/onecloud/pkg/httperrors"
- "yunion.io/x/onecloud/pkg/i18n"
- )
// IDP_ID_KEY is the placeholder segment embedded in the registered route
// paths. It doubles as the key under which the concrete identity-provider id
// is looked up in the request parameters, and it is substituted with that id
// when absolute URLs are built.
const IDP_ID_KEY = "<idp_id>"

// langTemplateKey is the key under which the custom HTML template entry is
// stored in, and later looked up from, the instance's i18n table.
const langTemplateKey = "lang_template_key"
- type OnSpInitiatedLogin func(ctx context.Context, idpId string, sp *SSAMLServiceProvider) (samlutils.SSAMLSpInitiatedLoginData, error)
- type OnIdpInitiatedLogin func(ctx context.Context, sp *SSAMLServiceProvider, IdpId, redirectUrl string) (samlutils.SSAMLIdpInitiatedLoginData, error)
- type OnLogout func(ctx context.Context, idpId string) string
- type SSAMLIdpInstance struct {
- saml *samlutils.SSAMLInstance
- metadataPath string
- redirectLoginPath string
- redirectLogoutPath string
- idpInitiatedSSOPath string
- serviceProviders []*SSAMLServiceProvider
- onSpInitiatedLogin OnSpInitiatedLogin
- onIdpInitiatedLogin OnIdpInitiatedLogin
- onLogout OnLogout
- htmlTemplate i18n.Table
- }
- func NewIdpInstance(saml *samlutils.SSAMLInstance, spLoginFunc OnSpInitiatedLogin, idpLoginFunc OnIdpInitiatedLogin, logoutFunc OnLogout) *SSAMLIdpInstance {
- return &SSAMLIdpInstance{
- saml: saml,
- onSpInitiatedLogin: spLoginFunc,
- onIdpInitiatedLogin: idpLoginFunc,
- onLogout: logoutFunc,
- htmlTemplate: i18n.Table{},
- }
- }
- func (idp *SSAMLIdpInstance) AddHandlers(app *appsrv.Application, prefix string, middleware appsrv.TMiddleware) {
- idp.metadataPath = httputils.JoinPath(prefix, "metadata/"+IDP_ID_KEY)
- idp.redirectLoginPath = httputils.JoinPath(prefix, "redirect/login/"+IDP_ID_KEY)
- idp.redirectLogoutPath = httputils.JoinPath(prefix, "redirect/logout/"+IDP_ID_KEY)
- idp.idpInitiatedSSOPath = httputils.JoinPath(prefix, "sso")
- app.AddHandler("GET", idp.metadataPath, idp.metadataHandler)
- handler := idp.redirectLoginHandler
- if middleware != nil {
- handler = middleware(handler)
- }
- app.AddHandler("POST", idp.redirectLoginPath, handler)
- app.AddHandler("GET", idp.redirectLoginPath, handler)
- handler = idp.redirectLogoutHandler
- if middleware != nil {
- handler = middleware(handler)
- }
- app.AddHandler("GET", idp.redirectLogoutPath, handler)
- handler = idp.idpInitiatedSSOHandler
- if middleware != nil {
- handler = middleware(handler)
- }
- app.AddHandler("GET", idp.idpInitiatedSSOPath, handler)
- log.Infof("IDP metadata: %s", idp.getMetadataUrl(IDP_ID_KEY))
- log.Infof("IDP redirect login: %s", idp.getRedirectLoginUrl(IDP_ID_KEY))
- log.Infof("IDP redirect logout: %s", idp.getRedirectLogoutUrl(IDP_ID_KEY))
- log.Infof("IDP initated SSO: %s", idp.getIdpInitiatedSSOUrl())
- }
- func (idp *SSAMLIdpInstance) SetHtmlTemplate(entry i18n.TableEntry) error {
- for _, tmp := range entry {
- if strings.Index(tmp, samlutils.HTML_SAML_FORM_TOKEN) < 0 {
- return errors.Wrapf(httperrors.ErrInvalidFormat, "no %s found", samlutils.HTML_SAML_FORM_TOKEN)
- }
- }
- idp.htmlTemplate.Set(langTemplateKey, entry)
- return nil
- }
- func (idp *SSAMLIdpInstance) AddSPMetadataFile(filename string) error {
- metadata, err := ioutil.ReadFile(filename)
- if err != nil {
- return errors.Wrap(err, "ioutil.ReadFile")
- }
- return idp.AddSPMetadata(metadata)
- }
- func (idp *SSAMLIdpInstance) AddSPMetadata(metadata []byte) error {
- ed, err := samlutils.ParseMetadata(metadata)
- if err != nil {
- return errors.Wrap(err, "samlutils.ParseMetadata")
- }
- sp := &SSAMLServiceProvider{desc: ed}
- err = sp.IsValid()
- if err != nil {
- return errors.Wrap(err, "NewSAMLServiceProvider")
- }
- log.Debugf("Register SP metadata: %s", sp.GetEntityId())
- idp.serviceProviders = append(idp.serviceProviders, sp)
- return nil
- }
- func (idp *SSAMLIdpInstance) getMetadataUrl(idpId string) string {
- return strings.Replace(httputils.JoinPath(idp.saml.GetEntityId(), idp.metadataPath), IDP_ID_KEY, idpId, 1)
- }
- func (idp *SSAMLIdpInstance) getRedirectLoginUrl(idpId string) string {
- return strings.Replace(httputils.JoinPath(idp.saml.GetEntityId(), idp.redirectLoginPath), IDP_ID_KEY, idpId, 1)
- }
- func (idp *SSAMLIdpInstance) getRedirectLogoutUrl(idpId string) string {
- return strings.Replace(httputils.JoinPath(idp.saml.GetEntityId(), idp.redirectLogoutPath), IDP_ID_KEY, idpId, 1)
- }
- func (idp *SSAMLIdpInstance) getIdpInitiatedSSOUrl() string {
- return httputils.JoinPath(idp.saml.GetEntityId(), idp.idpInitiatedSSOPath)
- }
- func (idp *SSAMLIdpInstance) metadataHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- params := appctx.AppContextParams(ctx)
- idpId := params[IDP_ID_KEY]
- desc := idp.GetMetadata(idpId)
- appsrv.SendXmlWithIndent(w, nil, desc, true)
- }
- func (idp *SSAMLIdpInstance) redirectLoginHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- params, query, _ := appsrv.FetchEnv(ctx, w, r)
- idpId := params[IDP_ID_KEY]
- input := samlutils.SIdpRedirectLoginInput{}
- err := query.Unmarshal(&input)
- if err != nil {
- httperrors.InputParameterError(ctx, w, "query.Unmarshal error %s", err)
- return
- }
- log.Debugf("recv input %s", input)
- respHtml, err := idp.processLoginRequest(ctx, idpId, input)
- if err != nil {
- httperrors.InputParameterError(ctx, w, "parse parameter error %s", err)
- return
- }
- appsrv.SendHTML(w, respHtml)
- }
- func (idp *SSAMLIdpInstance) redirectLogoutHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- params := appctx.AppContextParams(ctx)
- idpId := params[IDP_ID_KEY]
- log.Debugf("logout: %s", r.Header)
- html := idp.onLogout(ctx, idpId)
- appsrv.SendHTML(w, html)
- }
- func (idp *SSAMLIdpInstance) idpInitiatedSSOHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- _, query, _ := appsrv.FetchEnv(ctx, w, r)
- input := samlutils.SIdpInitiatedLoginInput{}
- err := query.Unmarshal(&input)
- if err != nil {
- httperrors.InputParameterError(ctx, w, "unmarshal input fail %s", err)
- return
- }
- respHtml, err := idp.processIdpInitiatedLogin(ctx, input)
- if err != nil {
- httperrors.GeneralServerError(ctx, w, err)
- return
- }
- appsrv.SendHTML(w, respHtml)
- }
- func (idp *SSAMLIdpInstance) GetMetadata(idpId string) samlutils.EntityDescriptor {
- input := samlutils.SSAMLIdpMetadataInput{
- EntityId: idp.saml.GetEntityId(),
- CertString: idp.saml.GetCertString(),
- RedirectLoginUrl: idp.getRedirectLoginUrl(idpId),
- RedirectLogoutUrl: idp.getRedirectLogoutUrl(idpId),
- }
- return samlutils.NewIdpMetadata(input)
- }
- func (idp *SSAMLIdpInstance) processLoginRequest(ctx context.Context, idpId string, input samlutils.SIdpRedirectLoginInput) (string, error) {
- plainText, err := samlutils.SAMLDecode(input.SAMLRequest)
- if err != nil {
- return "", errors.Wrap(err, "samlutils.SAMLDecode")
- }
- log.Debugf("AuthnRequest: %s", string(plainText))
- authReq := samlutils.AuthnRequest{}
- err = xml.Unmarshal(plainText, &authReq)
- if err != nil {
- return "", errors.Wrap(err, "xml.Unmarshal")
- }
- sp := idp.getServiceProvider(authReq.Issuer.Issuer)
- if sp == nil {
- return "", errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", authReq.Issuer.Issuer)
- }
- if len(authReq.Destination) > 0 && authReq.Destination != idp.getRedirectLoginUrl(idpId) {
- return "", errors.Wrapf(httperrors.ErrInputParameter, "Destination not match: get %s want %s", authReq.Destination, idp.getRedirectLoginUrl(idpId))
- }
- if len(authReq.AssertionConsumerServiceURL) > 0 && authReq.AssertionConsumerServiceURL != sp.GetPostAssertionConsumerServiceUrl() {
- return "", errors.Wrapf(httperrors.ErrInputParameter, "AssertionConsumerServiceURL not match: get %s want %s", authReq.AssertionConsumerServiceURL, sp.GetPostAssertionConsumerServiceUrl())
- }
- sp.Username = input.Username
- resp, err := idp.getLoginResponse(ctx, authReq, idpId, sp)
- if err != nil {
- return "", errors.Wrap(err, "getLoginResponse")
- }
- form, err := idp.samlResponse2Form(ctx, sp.GetPostAssertionConsumerServiceUrl(), resp, input.RelayState)
- if err != nil {
- return "", errors.Wrap(err, "samlResponse2Form")
- }
- return form, nil
- }
- func (idp *SSAMLIdpInstance) samlResponse2Form(ctx context.Context, url string, resp *samlutils.Response, state string) (string, error) {
- respXml, err := xml.Marshal(resp)
- if err != nil {
- return "", errors.Wrap(err, "xml.Marshal")
- }
- signed, err := idp.saml.SignXML(string(respXml))
- if err != nil {
- return "", errors.Wrap(err, "saml.SignXML")
- }
- log.Debugf("ResponseXML: %s", signed)
- samlResp := base64.StdEncoding.EncodeToString([]byte(signed))
- form := samlutils.SAMLForm(url, map[string]string{
- "SAMLResponse": samlResp,
- "RelayState": state,
- })
- template := samlutils.DEFAULT_HTML_TEMPLATE
- _temp := idp.htmlTemplate.Lookup(ctx, langTemplateKey)
- if _temp != langTemplateKey {
- template = _temp
- }
- form = strings.Replace(template, samlutils.HTML_SAML_FORM_TOKEN, form, 1)
- return form, nil
- }
- func (idp *SSAMLIdpInstance) getServiceProvider(eId string) *SSAMLServiceProvider {
- for _, sp := range idp.serviceProviders {
- if sp.GetEntityId() == eId {
- return sp
- }
- }
- return nil
- }
- func (idp *SSAMLIdpInstance) getLoginResponse(ctx context.Context, req samlutils.AuthnRequest, idpId string, sp *SSAMLServiceProvider) (*samlutils.Response, error) {
- data, err := idp.onSpInitiatedLogin(ctx, idpId, sp)
- if err != nil {
- return nil, errors.Wrap(err, "idp.onSpInitiatedLogin")
- }
- input := samlutils.SSAMLResponseInput{
- IssuerCertString: idp.saml.GetCertString(),
- IssuerEntityId: idp.saml.GetEntityId(),
- RequestID: req.ID,
- RequestEntityId: req.Issuer.Issuer,
- AssertionConsumerServiceURL: sp.GetPostAssertionConsumerServiceUrl(),
- SSAMLSpInitiatedLoginData: data,
- }
- resp := samlutils.NewResponse(input)
- return &resp, nil
- }
- func (idp *SSAMLIdpInstance) processIdpInitiatedLogin(ctx context.Context, input samlutils.SIdpInitiatedLoginInput) (string, error) {
- sp := idp.getServiceProvider(input.EntityID)
- if sp == nil {
- return "", errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", input.EntityID)
- }
- data, err := idp.onIdpInitiatedLogin(ctx, sp, input.IdpId, input.RedirectUrl)
- if err != nil {
- return "", errors.Wrap(err, "idp.onIdpInitiatedLogin")
- }
- if len(data.Form) > 0 {
- return data.Form, nil
- }
- respInput := samlutils.SSAMLResponseInput{
- IssuerCertString: idp.saml.GetCertString(),
- IssuerEntityId: idp.saml.GetEntityId(),
- RequestID: "",
- RequestEntityId: sp.GetEntityId(),
- AssertionConsumerServiceURL: sp.GetPostAssertionConsumerServiceUrl(),
- SSAMLSpInitiatedLoginData: data.SSAMLSpInitiatedLoginData,
- }
- resp := samlutils.NewResponse(respInput)
- form, err := idp.samlResponse2Form(ctx, sp.GetPostAssertionConsumerServiceUrl(), &resp, data.RelayState)
- if err != nil {
- return "", errors.Wrap(err, "samlResponse2Form")
- }
- return form, nil
- }
|