sp.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package sp
  15. import (
  16. "context"
  17. "encoding/base64"
  18. "encoding/xml"
  19. "io/ioutil"
  20. "net/http"
  21. "strings"
  22. "yunion.io/x/jsonutils"
  23. "yunion.io/x/log"
  24. "yunion.io/x/pkg/errors"
  25. "yunion.io/x/pkg/util/httputils"
  26. "yunion.io/x/pkg/util/samlutils"
  27. "yunion.io/x/onecloud/pkg/appsrv"
  28. "yunion.io/x/onecloud/pkg/httperrors"
  29. )
  30. type SSAMLAttribute struct {
  31. Name string
  32. FriendlyName string
  33. Values []string
  34. }
  35. type SSAMLAssertionConsumeResult struct {
  36. SSAMLSpInitiatedLoginRequest
  37. Attributes []SSAMLAttribute
  38. }
  39. type SSAMLSpInitiatedLoginRequest struct {
  40. RequestID string
  41. RelayState string
  42. }
  43. type OnSAMLAssertionConsume func(ctx context.Context, w http.ResponseWriter, idp *SSAMLIdentityProvider, result SSAMLAssertionConsumeResult) error
  44. type OnSAMLSpInitiatedLogin func(ctx context.Context, idp *SSAMLIdentityProvider) (SSAMLSpInitiatedLoginRequest, error)
  45. type SSAMLSpInstance struct {
  46. saml *samlutils.SSAMLInstance
  47. serviceName string
  48. metadataPath string
  49. assertionConsumerPath string
  50. spInitiatedSSOPath string
  51. assertionConsumerUri string
  52. identityProviders []*SSAMLIdentityProvider
  53. onSAMLAssertionConsume OnSAMLAssertionConsume
  54. onSAMLSpInitiatedLogin OnSAMLSpInitiatedLogin
  55. htmlTemplate string
  56. }
  57. func NewSpInstance(saml *samlutils.SSAMLInstance, serviceName string, consumeFunc OnSAMLAssertionConsume, loginFunc OnSAMLSpInitiatedLogin) *SSAMLSpInstance {
  58. return &SSAMLSpInstance{
  59. saml: saml,
  60. serviceName: serviceName,
  61. onSAMLAssertionConsume: consumeFunc,
  62. onSAMLSpInitiatedLogin: loginFunc,
  63. }
  64. }
  65. func (sp *SSAMLSpInstance) GetIdentityProviders() []*SSAMLIdentityProvider {
  66. return sp.identityProviders
  67. }
  68. func (sp *SSAMLSpInstance) AddIdpMetadataFile(filename string) error {
  69. metadata, err := ioutil.ReadFile(filename)
  70. if err != nil {
  71. return errors.Wrap(err, "ioutil.ReadFile")
  72. }
  73. return sp.AddIdpMetadata(metadata)
  74. }
  75. func (sp *SSAMLSpInstance) AddIdpMetadata(metadata []byte) error {
  76. ed, err := samlutils.ParseMetadata(metadata)
  77. if err != nil {
  78. return errors.Wrap(err, "samlutils.ParseMetadata")
  79. }
  80. idp, err := NewSAMLIdpFromDescriptor(ed)
  81. if err != nil {
  82. return errors.Wrap(err, "NewSAMLIdpFromDescriptor")
  83. }
  84. err = idp.IsValid()
  85. if err != nil {
  86. return errors.Wrap(err, "Invalid SAMLIdentityProvider")
  87. }
  88. log.Debugf("Register Idp metadata: %s", idp.GetEntityId())
  89. sp.identityProviders = append(sp.identityProviders, idp)
  90. return nil
  91. }
  92. func (sp *SSAMLSpInstance) AddIdp(entityId, redirectSsoUrl string) error {
  93. idp := NewSAMLIdp(entityId, redirectSsoUrl)
  94. err := idp.IsValid()
  95. if err != nil {
  96. return errors.Wrap(err, "Invalid SAMLIdentityProvider")
  97. }
  98. log.Debugf("Register Idp metadata: %s", idp.GetEntityId())
  99. sp.identityProviders = append(sp.identityProviders, idp)
  100. return nil
  101. }
  102. func (sp *SSAMLSpInstance) AddHandlers(app *appsrv.Application, prefix string) {
  103. sp.metadataPath = httputils.JoinPath(prefix, "metadata")
  104. sp.assertionConsumerPath = httputils.JoinPath(prefix, "acs")
  105. sp.spInitiatedSSOPath = httputils.JoinPath(prefix, "sso")
  106. app.AddHandler("GET", sp.metadataPath, sp.metadataHandler)
  107. app.AddHandler("POST", sp.assertionConsumerPath, sp.assertionConsumeHandler)
  108. app.AddHandler("GET", sp.spInitiatedSSOPath, sp.spInitiatedSSOHandler)
  109. log.Infof("SP metadata: %s", sp.getMetadataUrl())
  110. log.Infof("SP assertion consumer: %s", sp.getAssertionConsumerUrl())
  111. log.Infof("SP initated SSO: %s", sp.getSpInitiatedSSOUrl())
  112. }
  113. func (sp *SSAMLSpInstance) SetAssertionConsumerUri(uri string) {
  114. sp.assertionConsumerUri = uri
  115. }
  116. func (sp *SSAMLSpInstance) getMetadataUrl() string {
  117. return httputils.JoinPath(sp.saml.GetEntityId(), sp.metadataPath)
  118. }
  119. func (sp *SSAMLSpInstance) getAssertionConsumerUrl() string {
  120. if len(sp.assertionConsumerUri) > 0 {
  121. return sp.assertionConsumerUri
  122. }
  123. return httputils.JoinPath(sp.saml.GetEntityId(), sp.assertionConsumerPath)
  124. }
  125. func (sp *SSAMLSpInstance) getSpInitiatedSSOUrl() string {
  126. return httputils.JoinPath(sp.saml.GetEntityId(), sp.spInitiatedSSOPath)
  127. }
  128. func (sp *SSAMLSpInstance) metadataHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  129. desc := sp.GetMetadata()
  130. appsrv.SendXmlWithIndent(w, nil, desc, true)
  131. }
  132. func (sp *SSAMLSpInstance) assertionConsumeHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  133. samlResponse := r.FormValue("SAMLResponse")
  134. relayState := r.FormValue("RelayState")
  135. err := sp.processAssertionConsumer(ctx, w, samlResponse, relayState)
  136. if err != nil {
  137. httperrors.GeneralServerError(ctx, w, err)
  138. return
  139. }
  140. }
  141. func (sp *SSAMLSpInstance) spInitiatedSSOHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  142. _, query, _ := appsrv.FetchEnv(ctx, w, r)
  143. input := samlutils.SSpInitiatedLoginInput{}
  144. err := query.Unmarshal(&input)
  145. if err != nil {
  146. httperrors.InputParameterError(ctx, w, "unmarshal input fail %s", err)
  147. return
  148. }
  149. redirectUrl, err := sp.ProcessSpInitiatedLogin(ctx, input)
  150. if err != nil {
  151. httperrors.GeneralServerError(ctx, w, err)
  152. return
  153. }
  154. appsrv.SendRedirect(w, redirectUrl)
  155. }
  156. func (sp *SSAMLSpInstance) GetMetadata() samlutils.EntityDescriptor {
  157. input := samlutils.SSAMLSpMetadataInput{
  158. EntityId: sp.saml.GetEntityId(),
  159. CertString: sp.saml.GetCertString(),
  160. ServiceName: sp.serviceName,
  161. AssertionConsumerUrl: sp.getAssertionConsumerUrl(),
  162. RequestedAttributes: []samlutils.RequestedAttribute{
  163. {
  164. IsRequired: "false",
  165. Name: "userId",
  166. FriendlyName: "userId",
  167. },
  168. {
  169. IsRequired: "false",
  170. Name: "projectId",
  171. FriendlyName: "projectId",
  172. },
  173. {
  174. IsRequired: "false",
  175. Name: "roleId",
  176. FriendlyName: "roleId",
  177. },
  178. },
  179. }
  180. return samlutils.NewSpMetadata(input)
  181. }
  182. func (sp *SSAMLSpInstance) getIdentityProvider(eId string) *SSAMLIdentityProvider {
  183. for _, sp := range sp.identityProviders {
  184. if sp.GetEntityId() == eId {
  185. return sp
  186. }
  187. }
  188. return nil
  189. }
  190. func (sp *SSAMLSpInstance) processAssertionConsumer(ctx context.Context, w http.ResponseWriter, samlResponse string, relayState string) error {
  191. samlRespBytes, err := base64.StdEncoding.DecodeString(samlResponse)
  192. if err != nil {
  193. return errors.Wrap(err, "base64.StdEncoding.DecodeString")
  194. }
  195. log.Debugf("samlResponse: %s", string(samlRespBytes))
  196. samlResp, err := sp.saml.UnmarshalResponse(samlRespBytes)
  197. if err != nil {
  198. return errors.Wrap(err, "saml.UnmarshalResponse")
  199. }
  200. /*_, err = samlutils.ValidateXML(string(samlRespBytes))
  201. if err != nil {
  202. return errors.Wrap(err, "ValidateXML")
  203. }*/
  204. if !samlResp.IsSuccess() {
  205. return errors.Wrapf(httperrors.ErrInvalidCredential, "SAML authenticate fail: %s", samlResp.Status.StatusCode.Value)
  206. }
  207. idp := sp.getIdentityProvider(samlResp.Issuer.Issuer)
  208. if idp == nil {
  209. return errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", samlResp.Issuer.Issuer)
  210. }
  211. result := SSAMLAssertionConsumeResult{}
  212. if samlResp.InResponseTo != nil {
  213. result.RequestID = *samlResp.InResponseTo
  214. }
  215. result.RelayState = relayState
  216. if samlResp.Assertion != nil && samlResp.Assertion.AttributeStatement != nil {
  217. result.Attributes = make([]SSAMLAttribute, len(samlResp.Assertion.AttributeStatement.Attributes))
  218. for i, attr := range samlResp.Assertion.AttributeStatement.Attributes {
  219. values := make([]string, len(attr.AttributeValues))
  220. for i := range values {
  221. values[i] = attr.AttributeValues[i].Value
  222. }
  223. result.Attributes[i].Name = attr.Name
  224. if attr.FriendlyName != nil {
  225. result.Attributes[i].FriendlyName = *attr.FriendlyName
  226. }
  227. result.Attributes[i].Values = values
  228. }
  229. }
  230. err = sp.onSAMLAssertionConsume(ctx, w, idp, result)
  231. if err != nil {
  232. return errors.Wrap(err, "onSAMLAssertionConsume")
  233. }
  234. return nil
  235. }
  236. func (sp *SSAMLSpInstance) ProcessSpInitiatedLogin(ctx context.Context, input samlutils.SSpInitiatedLoginInput) (string, error) {
  237. idp := sp.getIdentityProvider(input.EntityID)
  238. if idp == nil {
  239. return "", errors.Wrapf(httperrors.ErrResourceNotFound, "issuer %s not found", input.EntityID)
  240. }
  241. loginReq, err := sp.onSAMLSpInitiatedLogin(ctx, idp)
  242. if err != nil {
  243. return "", errors.Wrap(err, "onSAMLSpInitiatedLogin")
  244. }
  245. reqInput := samlutils.SSAMLRequestInput{
  246. AssertionConsumerServiceURL: sp.getAssertionConsumerUrl(),
  247. Destination: idp.getRedirectSSOUrl(),
  248. RequestID: loginReq.RequestID,
  249. EntityID: sp.saml.GetEntityId(),
  250. }
  251. samlRequest := samlutils.NewRequest(reqInput)
  252. samlRequestXml, err := xml.Marshal(samlRequest)
  253. if err != nil {
  254. return "", errors.Wrap(err, "xml.Marshal")
  255. }
  256. reqStr, err := samlutils.SAMLEncode(samlRequestXml)
  257. if err != nil {
  258. return "", errors.Wrap(err, "SAMLEncode")
  259. }
  260. queryInput := samlutils.SIdpRedirectLoginInput{
  261. SAMLRequest: reqStr,
  262. RelayState: loginReq.RelayState,
  263. }
  264. queryStr := jsonutils.Marshal(queryInput).QueryString()
  265. redirectUrl := idp.getRedirectSSOUrl()
  266. if strings.IndexByte(redirectUrl, '?') > 0 {
  267. // non-empty query string
  268. redirectUrl += "&" + queryStr
  269. } else {
  270. redirectUrl += "?" + queryStr
  271. }
  272. return redirectUrl, nil
  273. }