| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package sp
- import (
- "context"
- "encoding/base64"
- "encoding/xml"
- "io/ioutil"
- "net/http"
- "strings"
- "yunion.io/x/jsonutils"
- "yunion.io/x/log"
- "yunion.io/x/pkg/errors"
- "yunion.io/x/pkg/util/httputils"
- "yunion.io/x/pkg/util/samlutils"
- "yunion.io/x/onecloud/pkg/appsrv"
- "yunion.io/x/onecloud/pkg/httperrors"
- )
// SSAMLAttribute is one attribute of a SAML assertion flattened into plain
// Go strings.
type SSAMLAttribute struct {
	// Name is the attribute name. FriendlyName is the optional friendly
	// name; it is left empty when the assertion does not carry one.
	Name, FriendlyName string
	// Values lists the attribute's values.
	Values []string
}
- type SSAMLAssertionConsumeResult struct {
- SSAMLSpInitiatedLoginRequest
- Attributes []SSAMLAttribute
- }
// SSAMLSpInitiatedLoginRequest carries the two values that tie an
// SP-initiated login request to the response that answers it.
type SSAMLSpInitiatedLoginRequest struct {
	// RequestID is the ID placed in the SAML request; RelayState is the
	// opaque value sent along with it in the redirect query.
	RequestID, RelayState string
}
- type OnSAMLAssertionConsume func(ctx context.Context, w http.ResponseWriter, idp *SSAMLIdentityProvider, result SSAMLAssertionConsumeResult) error
- type OnSAMLSpInitiatedLogin func(ctx context.Context, idp *SSAMLIdentityProvider) (SSAMLSpInitiatedLoginRequest, error)
- type SSAMLSpInstance struct {
- saml *samlutils.SSAMLInstance
- serviceName string
- metadataPath string
- assertionConsumerPath string
- spInitiatedSSOPath string
- assertionConsumerUri string
- identityProviders []*SSAMLIdentityProvider
- onSAMLAssertionConsume OnSAMLAssertionConsume
- onSAMLSpInitiatedLogin OnSAMLSpInitiatedLogin
- htmlTemplate string
- }
- func NewSpInstance(saml *samlutils.SSAMLInstance, serviceName string, consumeFunc OnSAMLAssertionConsume, loginFunc OnSAMLSpInitiatedLogin) *SSAMLSpInstance {
- return &SSAMLSpInstance{
- saml: saml,
- serviceName: serviceName,
- onSAMLAssertionConsume: consumeFunc,
- onSAMLSpInitiatedLogin: loginFunc,
- }
- }
- func (sp *SSAMLSpInstance) GetIdentityProviders() []*SSAMLIdentityProvider {
- return sp.identityProviders
- }
- func (sp *SSAMLSpInstance) AddIdpMetadataFile(filename string) error {
- metadata, err := ioutil.ReadFile(filename)
- if err != nil {
- return errors.Wrap(err, "ioutil.ReadFile")
- }
- return sp.AddIdpMetadata(metadata)
- }
- func (sp *SSAMLSpInstance) AddIdpMetadata(metadata []byte) error {
- ed, err := samlutils.ParseMetadata(metadata)
- if err != nil {
- return errors.Wrap(err, "samlutils.ParseMetadata")
- }
- idp, err := NewSAMLIdpFromDescriptor(ed)
- if err != nil {
- return errors.Wrap(err, "NewSAMLIdpFromDescriptor")
- }
- err = idp.IsValid()
- if err != nil {
- return errors.Wrap(err, "Invalid SAMLIdentityProvider")
- }
- log.Debugf("Register Idp metadata: %s", idp.GetEntityId())
- sp.identityProviders = append(sp.identityProviders, idp)
- return nil
- }
- func (sp *SSAMLSpInstance) AddIdp(entityId, redirectSsoUrl string) error {
- idp := NewSAMLIdp(entityId, redirectSsoUrl)
- err := idp.IsValid()
- if err != nil {
- return errors.Wrap(err, "Invalid SAMLIdentityProvider")
- }
- log.Debugf("Register Idp metadata: %s", idp.GetEntityId())
- sp.identityProviders = append(sp.identityProviders, idp)
- return nil
- }
- func (sp *SSAMLSpInstance) AddHandlers(app *appsrv.Application, prefix string) {
- sp.metadataPath = httputils.JoinPath(prefix, "metadata")
- sp.assertionConsumerPath = httputils.JoinPath(prefix, "acs")
- sp.spInitiatedSSOPath = httputils.JoinPath(prefix, "sso")
- app.AddHandler("GET", sp.metadataPath, sp.metadataHandler)
- app.AddHandler("POST", sp.assertionConsumerPath, sp.assertionConsumeHandler)
- app.AddHandler("GET", sp.spInitiatedSSOPath, sp.spInitiatedSSOHandler)
- log.Infof("SP metadata: %s", sp.getMetadataUrl())
- log.Infof("SP assertion consumer: %s", sp.getAssertionConsumerUrl())
- log.Infof("SP initated SSO: %s", sp.getSpInitiatedSSOUrl())
- }
- func (sp *SSAMLSpInstance) SetAssertionConsumerUri(uri string) {
- sp.assertionConsumerUri = uri
- }
- func (sp *SSAMLSpInstance) getMetadataUrl() string {
- return httputils.JoinPath(sp.saml.GetEntityId(), sp.metadataPath)
- }
- func (sp *SSAMLSpInstance) getAssertionConsumerUrl() string {
- if len(sp.assertionConsumerUri) > 0 {
- return sp.assertionConsumerUri
- }
- return httputils.JoinPath(sp.saml.GetEntityId(), sp.assertionConsumerPath)
- }
- func (sp *SSAMLSpInstance) getSpInitiatedSSOUrl() string {
- return httputils.JoinPath(sp.saml.GetEntityId(), sp.spInitiatedSSOPath)
- }
- func (sp *SSAMLSpInstance) metadataHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- desc := sp.GetMetadata()
- appsrv.SendXmlWithIndent(w, nil, desc, true)
- }
- func (sp *SSAMLSpInstance) assertionConsumeHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- samlResponse := r.FormValue("SAMLResponse")
- relayState := r.FormValue("RelayState")
- err := sp.processAssertionConsumer(ctx, w, samlResponse, relayState)
- if err != nil {
- httperrors.GeneralServerError(ctx, w, err)
- return
- }
- }
- func (sp *SSAMLSpInstance) spInitiatedSSOHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
- _, query, _ := appsrv.FetchEnv(ctx, w, r)
- input := samlutils.SSpInitiatedLoginInput{}
- err := query.Unmarshal(&input)
- if err != nil {
- httperrors.InputParameterError(ctx, w, "unmarshal input fail %s", err)
- return
- }
- redirectUrl, err := sp.ProcessSpInitiatedLogin(ctx, input)
- if err != nil {
- httperrors.GeneralServerError(ctx, w, err)
- return
- }
- appsrv.SendRedirect(w, redirectUrl)
- }
- func (sp *SSAMLSpInstance) GetMetadata() samlutils.EntityDescriptor {
- input := samlutils.SSAMLSpMetadataInput{
- EntityId: sp.saml.GetEntityId(),
- CertString: sp.saml.GetCertString(),
- ServiceName: sp.serviceName,
- AssertionConsumerUrl: sp.getAssertionConsumerUrl(),
- RequestedAttributes: []samlutils.RequestedAttribute{
- {
- IsRequired: "false",
- Name: "userId",
- FriendlyName: "userId",
- },
- {
- IsRequired: "false",
- Name: "projectId",
- FriendlyName: "projectId",
- },
- {
- IsRequired: "false",
- Name: "roleId",
- FriendlyName: "roleId",
- },
- },
- }
- return samlutils.NewSpMetadata(input)
- }
- func (sp *SSAMLSpInstance) getIdentityProvider(eId string) *SSAMLIdentityProvider {
- for _, sp := range sp.identityProviders {
- if sp.GetEntityId() == eId {
- return sp
- }
- }
- return nil
- }
- func (sp *SSAMLSpInstance) processAssertionConsumer(ctx context.Context, w http.ResponseWriter, samlResponse string, relayState string) error {
- samlRespBytes, err := base64.StdEncoding.DecodeString(samlResponse)
- if err != nil {
- return errors.Wrap(err, "base64.StdEncoding.DecodeString")
- }
- log.Debugf("samlResponse: %s", string(samlRespBytes))
- samlResp, err := sp.saml.UnmarshalResponse(samlRespBytes)
- if err != nil {
- return errors.Wrap(err, "saml.UnmarshalResponse")
- }
- /*_, err = samlutils.ValidateXML(string(samlRespBytes))
- if err != nil {
- return errors.Wrap(err, "ValidateXML")
- }*/
- if !samlResp.IsSuccess() {
- return errors.Wrapf(httperrors.ErrInvalidCredential, "SAML authenticate fail: %s", samlResp.Status.StatusCode.Value)
- }
- idp := sp.getIdentityProvider(samlResp.Issuer.Issuer)
- if idp == nil {
- return errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", samlResp.Issuer.Issuer)
- }
- result := SSAMLAssertionConsumeResult{}
- if samlResp.InResponseTo != nil {
- result.RequestID = *samlResp.InResponseTo
- }
- result.RelayState = relayState
- if samlResp.Assertion != nil && samlResp.Assertion.AttributeStatement != nil {
- result.Attributes = make([]SSAMLAttribute, len(samlResp.Assertion.AttributeStatement.Attributes))
- for i, attr := range samlResp.Assertion.AttributeStatement.Attributes {
- values := make([]string, len(attr.AttributeValues))
- for i := range values {
- values[i] = attr.AttributeValues[i].Value
- }
- result.Attributes[i].Name = attr.Name
- if attr.FriendlyName != nil {
- result.Attributes[i].FriendlyName = *attr.FriendlyName
- }
- result.Attributes[i].Values = values
- }
- }
- err = sp.onSAMLAssertionConsume(ctx, w, idp, result)
- if err != nil {
- return errors.Wrap(err, "onSAMLAssertionConsume")
- }
- return nil
- }
- func (sp *SSAMLSpInstance) ProcessSpInitiatedLogin(ctx context.Context, input samlutils.SSpInitiatedLoginInput) (string, error) {
- idp := sp.getIdentityProvider(input.EntityID)
- if idp == nil {
- return "", errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", input.EntityID)
- }
- loginReq, err := sp.onSAMLSpInitiatedLogin(ctx, idp)
- if err != nil {
- return "", errors.Wrap(err, "onSAMLSpInitiatedLogin")
- }
- reqInput := samlutils.SSAMLRequestInput{
- AssertionConsumerServiceURL: sp.getAssertionConsumerUrl(),
- Destination: idp.getRedirectSSOUrl(),
- RequestID: loginReq.RequestID,
- EntityID: sp.saml.GetEntityId(),
- }
- samlRequest := samlutils.NewRequest(reqInput)
- samlRequestXml, err := xml.Marshal(samlRequest)
- if err != nil {
- return "", errors.Wrap(err, "xml.Marshal")
- }
- reqStr, err := samlutils.SAMLEncode(samlRequestXml)
- if err != nil {
- return "", errors.Wrap(err, "SAMLEncode")
- }
- queryInput := samlutils.SIdpRedirectLoginInput{
- SAMLRequest: reqStr,
- RelayState: loginReq.RelayState,
- }
- queryStr := jsonutils.Marshal(queryInput).QueryString()
- redirectUrl := idp.getRedirectSSOUrl()
- if strings.IndexByte(redirectUrl, '?') > 0 {
- // non-empty query string
- redirectUrl += "&" + queryStr
- } else {
- redirectUrl += "?" + queryStr
- }
- return redirectUrl, nil
- }
|