// Copyright 2019 Yunion
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cas
import (
"context"
"fmt"
"regexp"
"strings"
"yunion.io/x/jsonutils"
"yunion.io/x/log"
"yunion.io/x/pkg/errors"
"yunion.io/x/pkg/util/httputils"
api "yunion.io/x/onecloud/pkg/apis/identity"
"yunion.io/x/onecloud/pkg/keystone/driver"
"yunion.io/x/onecloud/pkg/keystone/models"
"yunion.io/x/onecloud/pkg/mcclient"
)
// SCASDriver is the identity backend driver for Apereo CAS
// (Central Authentication Service).
type SCASDriver struct {
	driver.SBaseIdentityDriver
	// casConfig holds the CAS options decoded from the "cas" entry of the
	// driver configuration; it is filled in by prepareConfig.
	casConfig *api.SCASIdpConfigOptions
	// isDebug is forwarded to the httputils request/response helpers.
	// NOTE(review): nothing in this file assigns it, so it stays false here.
	isDebug bool
}
// NewCASDriver builds a CAS identity backend on top of the base identity
// driver and decodes its CAS options before handing it back.
//
// All arguments are passed unchanged to driver.NewBaseIdentityDriver. An error
// from that constructor or from prepareConfig is wrapped and returned.
func NewCASDriver(idpId, idpName, template, targetDomainId string, conf api.TConfigs) (driver.IIdentityBackend, error) {
	baseDrv, err := driver.NewBaseIdentityDriver(idpId, idpName, template, targetDomainId, conf)
	if err != nil {
		return nil, errors.Wrap(err, "NewBaseIdentityDriver")
	}

	casDrv := &SCASDriver{SBaseIdentityDriver: baseDrv}
	casDrv.SetVirtualObject(casDrv)
	if err := casDrv.prepareConfig(); err != nil {
		return nil, errors.Wrap(err, "prepareConfig")
	}
	return casDrv, nil
}
// prepareConfig decodes the "cas" entry of self.Config into casConfig.
//
// It does nothing when casConfig is already set. When the user ID or user
// name attribute is left empty in the configuration, it defaults to
// "cas:user". A decoding failure is wrapped and returned.
func (self *SCASDriver) prepareConfig() error {
	if self.casConfig != nil {
		// Already decoded; keep the existing options.
		return nil
	}

	var opts api.SCASIdpConfigOptions
	rawConf := jsonutils.Marshal(self.Config["cas"])
	if err := rawConf.Unmarshal(&opts); err != nil {
		return errors.Wrap(err, "json.Unmarshal")
	}

	if opts.UserIdAttribute == "" {
		opts.UserIdAttribute = "cas:user"
	}
	if opts.UserNameAttribute == "" {
		opts.UserNameAttribute = "cas:user"
	}

	self.casConfig = &opts
	log.Debugf("%s %s %#v", self.Config, rawConf, self.casConfig)
	return nil
}
// GetSsoRedirectUri returns the configured CAS server URL with callbackUrl
// attached as the "service" query parameter.
//
// ctx and state are not used by this implementation, and the returned error
// is always nil.
func (self *SCASDriver) GetSsoRedirectUri(ctx context.Context, callbackUrl, state string) (string, error) {
	params := map[string]string{"service": callbackUrl}
	queryStr := jsonutils.Marshal(params).QueryString()
	return fmt.Sprintf("%s?%s", self.casConfig.CASServerURL, queryStr), nil
}
// request sends a body-less HTTP request with the given method to path,
// joined onto the configured CAS server URL, using the default httputils
// client. It returns the body and error produced by httputils.ParseResponse.
func (self *SCASDriver) request(ctx context.Context, method httputils.THttpMethod, path string) ([]byte, error) {
	endpoint := httputils.JoinPath(self.casConfig.CASServerURL, path)
	resp, reqErr := httputils.Request(httputils.GetDefaultClient(), ctx, method, endpoint, nil, nil, self.isDebug)
	// The request error is handed to ParseResponse together with the response.
	_, payload, parseErr := httputils.ParseResponse("", resp, reqErr, self.isDebug)
	return payload, parseErr
}
/*
Sample serviceValidate response (only the element text values are listed
here; the surrounding XML tags are not shown):
casuser
casuser
UsernamePasswordCredential
false
2019-09-05T12:40:08.014Z[UTC]
AcceptUsersAuthenticationHandler
AcceptUsersAuthenticationHandler
false
*/
/*type SCASServiceResponse struct {
XMLName xml.Name `xml:"serviceResponse"`
CASAuthenticationSuccess struct {
CASUser string `xml:"user"`
} `xml:"authenticationSuccess"`
}*/
// Authenticate validates the CAS ticket carried in ident against the CAS
// server's serviceValidate endpoint and maps the result to a local user.
//
// The domain and user ID/name are read from the response attributes named in
// casConfig; a missing attribute yields an empty string. The identity provider
// then syncs or creates the domain and user, the extended user record is
// fetched, and TryUserJoinProject is called with the response attributes.
// The ticket ID is recorded in the returned user's AuditIds.
//
// Any failing step returns a nil user and the wrapped error.
func (self *SCASDriver) Authenticate(ctx context.Context, ident mcclient.SAuthenticationIdentity) (*api.SUserExtended, error) {
	params := jsonutils.NewDict()
	params.Set("ticket", jsonutils.NewString(ident.CASTicket.Id))
	params.Set("service", jsonutils.NewString(ident.CASTicket.Service))

	body, err := self.request(ctx, "GET", "serviceValidate?"+params.QueryString())
	if err != nil {
		return nil, errors.Wrap(err, "self.request")
	}
	log.Debugf("CAS response: %s qs: %s", body, params.QueryString())

	attrs := fetchAttributes(body)
	// firstValue returns the first value recorded for key, or "" if there is none.
	firstValue := func(key string) string {
		if vals := attrs[key]; len(vals) > 0 {
			return vals[0]
		}
		return ""
	}
	domainId := firstValue(self.casConfig.DomainIdAttribute)
	domainName := firstValue(self.casConfig.DomainNameAttribute)
	usrId := firstValue(self.casConfig.UserIdAttribute)
	usrName := firstValue(self.casConfig.UserNameAttribute)
	// NOTE(review): usrId/usrName may be empty here if the response carried no
	// user attribute; confirm SyncOrCreateDomainAndUser rejects that case.

	idp, err := models.IdentityProviderManager.FetchIdentityProviderById(self.IdpId)
	if err != nil {
		return nil, errors.Wrap(err, "self.GetIdentityProvider")
	}

	domain, usr, err := idp.SyncOrCreateDomainAndUser(ctx, domainId, domainName, usrId, usrName)
	if err != nil {
		return nil, errors.Wrap(err, "idp.SyncOrCreateDomainAndUser")
	}

	extUser, err := models.UserManager.FetchUserExtended(usr.Id, "", "", "")
	if err != nil {
		return nil, errors.Wrap(err, "models.UserManager.FetchUserExtended")
	}

	idp.TryUserJoinProject(self.casConfig.SIdpAttributeOptions, ctx, usr, domain.Id, attrs)

	extUser.AuditIds = []string{ident.CASTicket.Id}
	return extUser, nil
}
/*func (self *SCASDriver) userTryJoinProject(ctx context.Context, usr *models.SUser, domainId string, resp []byte) {
var err error
var targetProject *models.SProject
log.Debugf("userTryJoinProject resp %s proj %s", string(resp), self.casConfig.CasProjectAttribute)
if !consts.GetNonDefaultDomainProjects() {
domainId = api.DEFAULT_DOMAIN_ID
}
if len(self.casConfig.CasProjectAttribute) > 0 {
projName := fetchAttribute(resp, self.casConfig.CasProjectAttribute)
if len(projName) > 0 {
targetProject, err = models.ProjectManager.FetchProject("", projName, domainId, "")
if err != nil {
log.Errorf("fetch project %s fail %s", projName, err)
if errors.Cause(err) == sql.ErrNoRows && self.casConfig.AutoCreateCasProject.IsTrue() {
targetProject, err = models.ProjectManager.NewProject(ctx, projName, "cas project", domainId)
if err != nil {
log.Errorf("auto create project %s fail %s", projName, err)
}
}
}
}
}
if targetProject == nil && len(self.casConfig.DefaultCasProjectId) > 0 {
targetProject, err = models.ProjectManager.FetchProjectById(self.casConfig.DefaultCasProjectId)
if err != nil {
log.Errorf("fetch default project %s fail %s", self.casConfig.DefaultCasProjectId, err)
}
}
if targetProject != nil {
// put user in project
var targetRole *models.SRole
if len(self.casConfig.CasRoleAttribute) > 0 {
roleName := fetchAttribute(resp, self.casConfig.CasRoleAttribute)
if len(roleName) > 0 {
targetRole, err = models.RoleManager.FetchRole("", roleName, domainId, "")
if err != nil {
log.Errorf("fetch role %s fail %s", roleName, err)
}
}
}
if targetRole == nil && len(self.casConfig.DefaultCasRoleId) > 0 {
targetRole, err = models.RoleManager.FetchRoleById(self.casConfig.DefaultCasRoleId)
if err != nil {
log.Errorf("fetch default role %s fail %s", self.casConfig.DefaultCasRoleId, err)
}
}
if targetRole != nil {
err = models.AssignmentManager.ProjectAddUser(ctx, models.GetDefaultAdminCred(), targetProject, usr, targetRole)
if err != nil {
log.Errorf("CAS user join project fail %s", err)
}
}
}
}*/
// casAttributePattern matches one leaf XML element of the form
// <name>text</name>. Group 1 is the content of the opening tag and group 2 is
// the element text. The closing tag must begin with "</"; without that, a
// container element followed by whitespace and a child opening tag would be
// matched as an empty leaf and the child element would be swallowed.
var casAttributePattern = regexp.MustCompile(`<([^>/]+)>([^<]*)</[^>]+>`)

// fetchAttributes scans heystack for leaf elements written as
// <name>text</name> and returns every text value, whitespace-trimmed, grouped
// by element name in order of appearance.
//
// Elements that only contain other elements are skipped. The closing tag name
// is not compared with the opening one. The result is never nil; it is empty
// when nothing matches.
func fetchAttributes(heystack []byte) map[string][]string {
	ret := make(map[string][]string)
	for _, match := range casAttributePattern.FindAllSubmatch(heystack, -1) {
		key := string(match[1])
		// append on a missing key starts from a nil slice, so no pre-check is needed.
		ret[key] = append(ret[key], strings.TrimSpace(string(match[2])))
	}
	return ret
}
// Sync is a no-op for the CAS driver and always returns nil.
func (self *SCASDriver) Sync(_ context.Context) error {
	return nil
}
// Probe issues a GET for the "login" path on the CAS server.
//
// An error whose httputils error code is 401 is treated as success; any other
// error is wrapped and returned.
func (self *SCASDriver) Probe(ctx context.Context) error {
	if _, err := self.request(ctx, "GET", "login"); err != nil {
		if httputils.ErrorCode(err) != 401 {
			return errors.Wrap(err, "self.request")
		}
	}
	return nil
}