| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package ldaputils
- import (
- "crypto/tls"
- "encoding/hex"
- "fmt"
- "strings"
- "github.com/go-ldap/ldap/v3"
- "yunion.io/x/log"
- "yunion.io/x/pkg/errors"
- )
- var (
- ErrUserNotFound = errors.Error("not found")
- ErrUserDuplicate = errors.Error("user id duplicate")
- ErrUserBadCredential = errors.Error("bad credential")
- binaryAttributes = []string{
- "objectGUID",
- "objectSid",
- }
- )
- type SLDAPClient struct {
- url string
- account string
- password string
- baseDN string
- isDebug bool
- conn *ldap.Conn
- }
- func NewLDAPClient(url, account, password string, baseDN string, isDebug bool) *SLDAPClient {
- return &SLDAPClient{
- url: url,
- account: account,
- password: password,
- baseDN: baseDN,
- isDebug: isDebug,
- }
- }
- func (cli *SLDAPClient) Connect() error {
- var err error
- cli.conn, err = ldap.DialURL(cli.url, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
- if err != nil {
- return errors.Wrap(err, "DiaURL")
- }
- return cli.bind()
- }
- func (cli *SLDAPClient) bind() error {
- if len(cli.account) > 0 {
- err := cli.conn.Bind(cli.account, cli.password)
- if err != nil {
- return errors.Wrap(err, "Bind")
- }
- }
- return nil
- }
- func (cli *SLDAPClient) Close() {
- if cli.conn != nil {
- cli.conn.Close()
- cli.conn = nil
- }
- }
- func (cli *SLDAPClient) Authenticate(baseDN string, objClass string, uidAttr string, uname string, passwd string, filter string, fields []string, queryScope int) (*ldap.Entry, error) {
- attrMap := make(map[string]string)
- attrMap[uidAttr] = uname
- var retEntry *ldap.Entry
- total, err := cli.Search(
- baseDN, objClass, attrMap, filter, fields, queryScope,
- 2, 2,
- func(offset uint32, entry *ldap.Entry) error {
- retEntry = entry
- return nil
- },
- )
- if err != nil {
- return nil, errors.Wrap(err, "Search")
- }
- if total == 0 {
- return nil, ErrUserNotFound
- }
- if total > 1 {
- return nil, ErrUserDuplicate
- }
- defer cli.bind()
- err = cli.conn.Bind(retEntry.DN, passwd)
- if err != nil {
- return nil, ErrUserBadCredential
- }
- return retEntry, nil
- }
- const (
- StopSearch = errors.Error("stop dap search")
- )
- // support pagination
- // reference: https://zerokspot.com/weblog/2018/03/07/paging-in-gopkg-ldap/
- func (cli *SLDAPClient) Search(
- base string, objClass string,
- condition map[string]string,
- filter string,
- fields []string,
- queryScope int,
- pageSize uint32, limit uint32,
- entryFunc func(offset uint32, entry *ldap.Entry) error,
- ) (uint32, error) {
- searches := strings.Builder{}
- if len(condition) == 0 && len(objClass) == 0 {
- searches.WriteString("(objectClass=*)")
- }
- if len(objClass) > 0 {
- searches.WriteString("(objectClass=")
- searches.WriteString(objClass)
- searches.WriteString(")")
- }
- for k, v := range condition {
- searches.WriteString("(")
- searches.WriteString(k)
- searches.WriteString("=")
- if isBinaryAttr(k) {
- v = toBinarySearchString(v)
- }
- searches.WriteString(v)
- searches.WriteString(")")
- }
- if len(filter) > 0 && strings.HasPrefix(filter, "(") && strings.HasSuffix(filter, ")") {
- searches.WriteString(filter)
- }
- searchStr := fmt.Sprintf("(&%s)", searches.String())
- if len(base) == 0 {
- base = cli.baseDN
- }
- if queryScope != ldap.ScopeWholeSubtree && queryScope != ldap.ScopeSingleLevel && queryScope != ldap.ScopeBaseObject {
- queryScope = ldap.ScopeWholeSubtree
- }
- log.Debugf("ldapSearch: %s", searchStr)
- offset := uint32(0)
- paging := ldap.NewControlPaging(pageSize)
- for {
- searchRequest := ldap.NewSearchRequest(
- base, // The base dn to search
- queryScope, ldap.NeverDerefAliases, 0, 0, false,
- searchStr,
- fields, // A list attributes to retrieve
- []ldap.Control{paging},
- )
- searchResult, err := cli.conn.Search(searchRequest)
- if err != nil {
- return offset, errors.Wrap(err, "Seearch")
- }
- pageTotal := len(searchResult.Entries)
- for i := 0; i < pageTotal; i++ {
- entry := searchResult.Entries[i]
- err := entryFunc(offset, entry)
- if err != nil && errors.Cause(err) != StopSearch {
- return offset, errors.Wrapf(err, "process entry fail at %d-%d-%d", pageTotal, i, offset)
- }
- offset += 1
- if (err != nil && errors.Cause(err) == StopSearch) || (limit > 0 && offset >= limit) {
- // stop, offset exceeds limit or receive StopSearch error
- return offset, nil
- }
- }
- resultCtrl := ldap.FindControl(searchResult.Controls, paging.GetControlType())
- if resultCtrl == nil {
- break
- }
- if pagingCtrl, ok := resultCtrl.(*ldap.ControlPaging); ok {
- if len(pagingCtrl.Cookie) == 0 {
- break
- }
- paging.SetCookie(pagingCtrl.Cookie)
- }
- }
- return offset, nil
- }
- func isBinaryAttr(attrName string) bool {
- for _, attr := range binaryAttributes {
- if strings.EqualFold(attr, attrName) {
- return true
- }
- }
- return false
- }
- func toBinary(val string) string {
- ret, err := hex.DecodeString(val)
- if err != nil {
- return val
- } else {
- return string(ret)
- }
- }
- func toBinarySearchString(val string) string {
- ret := strings.Builder{}
- for i := 0; i < len(val); i += 2 {
- ret.WriteString(`\`)
- ret.WriteString(val[i : i+2])
- }
- return ret.String()
- }
- func toHex(val string) string {
- return hex.EncodeToString([]byte(val))
- }
- func GetAttributeValue(e *ldap.Entry, key string) string {
- val := e.GetAttributeValue(key)
- if isBinaryAttr(key) {
- val = toHex(val)
- }
- return val
- }
- func GetAttributeValues(e *ldap.Entry, key string) []string {
- vals := e.GetAttributeValues(key)
- if isBinaryAttr(key) {
- for i := range vals {
- vals[i] = toHex(vals[i])
- }
- }
- return vals
- }
|