conn.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. package ldap
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "net/url"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. ber "github.com/go-asn1-ber/asn1-ber"
  13. )
  // Operation codes carried in messagePacket.Op and dispatched by the
  // processMessages loop.
  const (
  	// MessageQuit causes the processMessages loop to exit
  	MessageQuit = iota
  	// MessageRequest sends a request to the server
  	MessageRequest
  	// MessageResponse receives a response from the server
  	MessageResponse
  	// MessageFinish indicates the client considers a particular message ID to be finished
  	MessageFinish
  	// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
  	MessageTimeout
  )
  // DefaultLdapPort is the port used for ldap:// URLs that do not name one
  // (plain TCP connection).
  const DefaultLdapPort = "389"

  // DefaultLdapsPort is the port used for ldaps:// URLs that do not name one
  // (SSL connection).
  const DefaultLdapsPort = "636"
  32. // PacketResponse contains the packet or error encountered reading a response
  33. type PacketResponse struct {
  34. // Packet is the packet read from the server
  35. Packet *ber.Packet
  36. // Error is an error encountered while reading
  37. Error error
  38. }
  39. // ReadPacket returns the packet or an error
  40. func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
  41. if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
  42. return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
  43. }
  44. return pr.Packet, pr.Error
  45. }
  46. type messageContext struct {
  47. id int64
  48. // close(done) should only be called from finishMessage()
  49. done chan struct{}
  50. // close(responses) should only be called from processMessages(), and only sent to from sendResponse()
  51. responses chan *PacketResponse
  52. }
  53. // sendResponse should only be called within the processMessages() loop which
  54. // is also responsible for closing the responses channel.
  55. func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
  56. select {
  57. case msgCtx.responses <- packet:
  58. // Successfully sent packet to message handler.
  59. case <-msgCtx.done:
  60. // The request handler is done and will not receive more
  61. // packets.
  62. }
  63. }
  64. type messagePacket struct {
  65. Op int
  66. MessageID int64
  67. Packet *ber.Packet
  68. Context *messageContext
  69. }
  // sendMessageFlags is a bit set that modifies how sendMessageWithFlags
  // treats a request.
  type sendMessageFlags uint

  // startTLS marks the request as the StartTLS extended operation.
  const startTLS sendMessageFlags = 1 << iota
  74. // Conn represents an LDAP Connection
  75. type Conn struct {
  76. // requestTimeout is loaded atomically
  77. // so we need to ensure 64-bit alignment on 32-bit platforms.
  78. requestTimeout int64
  79. conn net.Conn
  80. isTLS bool
  81. closing uint32
  82. closeErr atomic.Value
  83. isStartingTLS bool
  84. Debug debugging
  85. chanConfirm chan struct{}
  86. messageContexts map[int64]*messageContext
  87. chanMessage chan *messagePacket
  88. chanMessageID chan int64
  89. wgClose sync.WaitGroup
  90. outstandingRequests uint
  91. messageMutex sync.Mutex
  92. }
  93. var _ Client = &Conn{}
  // DefaultTimeout is the connect timeout applied by Dial and DialTLS, and by
  // DialURL when no custom dialer is supplied.
  //
  // WARNING: this is package-level state; changing it from several places
  // will probably lead to undesired behaviour.
  var DefaultTimeout = time.Minute
  // DialOpt is a functional option that adjusts a DialContext.
  type DialOpt func(*DialContext)

  // DialWithDialer returns an option that makes the DialContext use d as its
  // net.Dialer.
  func DialWithDialer(d *net.Dialer) DialOpt {
  	return func(target *DialContext) { target.d = d }
  }

  // DialWithTLSConfig returns an option that makes the DialContext use tc as
  // its tls.Config.
  func DialWithTLSConfig(tc *tls.Config) DialOpt {
  	return func(target *DialContext) { target.tc = tc }
  }

  // DialContext holds the parameters needed to dial a given ldap URL.
  type DialContext struct {
  	// d is the dialer used to open the network connection.
  	d *net.Dialer
  	// tc is the TLS configuration used for ldaps:// URLs.
  	tc *tls.Config
  }
  119. func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
  120. if u.Scheme == "ldapi" {
  121. if u.Path == "" || u.Path == "/" {
  122. u.Path = "/var/run/slapd/ldapi"
  123. }
  124. return dc.d.Dial("unix", u.Path)
  125. }
  126. host, port, err := net.SplitHostPort(u.Host)
  127. if err != nil {
  128. // we asume that error is due to missing port
  129. host = u.Host
  130. port = ""
  131. }
  132. switch u.Scheme {
  133. case "ldap":
  134. if port == "" {
  135. port = DefaultLdapPort
  136. }
  137. return dc.d.Dial("tcp", net.JoinHostPort(host, port))
  138. case "ldaps":
  139. if port == "" {
  140. port = DefaultLdapsPort
  141. }
  142. return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
  143. }
  144. return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
  145. }
  146. // Dial connects to the given address on the given network using net.Dial
  147. // and then returns a new Conn for the connection.
  148. // @deprecated Use DialURL instead.
  149. func Dial(network, addr string) (*Conn, error) {
  150. c, err := net.DialTimeout(network, addr, DefaultTimeout)
  151. if err != nil {
  152. return nil, NewError(ErrorNetwork, err)
  153. }
  154. conn := NewConn(c, false)
  155. conn.Start()
  156. return conn, nil
  157. }
  158. // DialTLS connects to the given address on the given network using tls.Dial
  159. // and then returns a new Conn for the connection.
  160. // @deprecated Use DialURL instead.
  161. func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
  162. c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
  163. if err != nil {
  164. return nil, NewError(ErrorNetwork, err)
  165. }
  166. conn := NewConn(c, true)
  167. conn.Start()
  168. return conn, nil
  169. }
  170. // DialURL connects to the given ldap URL.
  171. // The following schemas are supported: ldap://, ldaps://, ldapi://.
  172. // On success a new Conn for the connection is returned.
  173. func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
  174. u, err := url.Parse(addr)
  175. if err != nil {
  176. return nil, NewError(ErrorNetwork, err)
  177. }
  178. var dc DialContext
  179. for _, opt := range opts {
  180. opt(&dc)
  181. }
  182. if dc.d == nil {
  183. dc.d = &net.Dialer{Timeout: DefaultTimeout}
  184. }
  185. c, err := dc.dial(u)
  186. if err != nil {
  187. return nil, NewError(ErrorNetwork, err)
  188. }
  189. conn := NewConn(c, u.Scheme == "ldaps")
  190. conn.Start()
  191. return conn, nil
  192. }
  193. // NewConn returns a new Conn using conn for network I/O.
  194. func NewConn(conn net.Conn, isTLS bool) *Conn {
  195. return &Conn{
  196. conn: conn,
  197. chanConfirm: make(chan struct{}),
  198. chanMessageID: make(chan int64),
  199. chanMessage: make(chan *messagePacket, 10),
  200. messageContexts: map[int64]*messageContext{},
  201. requestTimeout: 0,
  202. isTLS: isTLS,
  203. }
  204. }
  205. // Start initializes goroutines to read responses and process messages
  206. func (l *Conn) Start() {
  207. l.wgClose.Add(1)
  208. go l.reader()
  209. go l.processMessages()
  210. }
  211. // IsClosing returns whether or not we're currently closing.
  212. func (l *Conn) IsClosing() bool {
  213. return atomic.LoadUint32(&l.closing) == 1
  214. }
  215. // setClosing sets the closing value to true
  216. func (l *Conn) setClosing() bool {
  217. return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
  218. }
  219. // Close closes the connection.
  220. func (l *Conn) Close() {
  221. l.messageMutex.Lock()
  222. defer l.messageMutex.Unlock()
  223. if l.setClosing() {
  224. l.Debug.Printf("Sending quit message and waiting for confirmation")
  225. l.chanMessage <- &messagePacket{Op: MessageQuit}
  226. <-l.chanConfirm
  227. close(l.chanMessage)
  228. l.Debug.Printf("Closing network connection")
  229. if err := l.conn.Close(); err != nil {
  230. log.Println(err)
  231. }
  232. l.wgClose.Done()
  233. }
  234. l.wgClose.Wait()
  235. }
  236. // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
  237. func (l *Conn) SetTimeout(timeout time.Duration) {
  238. if timeout > 0 {
  239. atomic.StoreInt64(&l.requestTimeout, int64(timeout))
  240. }
  241. }
  242. // Returns the next available messageID
  243. func (l *Conn) nextMessageID() int64 {
  244. if messageID, ok := <-l.chanMessageID; ok {
  245. return messageID
  246. }
  247. return 0
  248. }
  249. // StartTLS sends the command to start a TLS session and then creates a new TLS Client
  250. func (l *Conn) StartTLS(config *tls.Config) error {
  251. if l.isTLS {
  252. return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
  253. }
  254. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  255. packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
  256. request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
  257. request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
  258. packet.AppendChild(request)
  259. l.Debug.PrintPacket(packet)
  260. msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
  261. if err != nil {
  262. return err
  263. }
  264. defer l.finishMessage(msgCtx)
  265. l.Debug.Printf("%d: waiting for response", msgCtx.id)
  266. packetResponse, ok := <-msgCtx.responses
  267. if !ok {
  268. return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
  269. }
  270. packet, err = packetResponse.ReadPacket()
  271. l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
  272. if err != nil {
  273. return err
  274. }
  275. if l.Debug {
  276. if err := addLDAPDescriptions(packet); err != nil {
  277. l.Close()
  278. return err
  279. }
  280. l.Debug.PrintPacket(packet)
  281. }
  282. if err := GetLDAPError(packet); err == nil {
  283. conn := tls.Client(l.conn, config)
  284. if connErr := conn.Handshake(); connErr != nil {
  285. l.Close()
  286. return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
  287. }
  288. l.isTLS = true
  289. l.conn = conn
  290. } else {
  291. return err
  292. }
  293. go l.reader()
  294. return nil
  295. }
  296. // TLSConnectionState returns the client's TLS connection state.
  297. // The return values are their zero values if StartTLS did
  298. // not succeed.
  299. func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
  300. tc, ok := l.conn.(*tls.Conn)
  301. if !ok {
  302. return
  303. }
  304. return tc.ConnectionState(), true
  305. }
  306. func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
  307. return l.sendMessageWithFlags(packet, 0)
  308. }
  309. func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
  310. if l.IsClosing() {
  311. return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
  312. }
  313. l.messageMutex.Lock()
  314. l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
  315. if l.isStartingTLS {
  316. l.messageMutex.Unlock()
  317. return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
  318. }
  319. if flags&startTLS != 0 {
  320. if l.outstandingRequests != 0 {
  321. l.messageMutex.Unlock()
  322. return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
  323. }
  324. l.isStartingTLS = true
  325. }
  326. l.outstandingRequests++
  327. l.messageMutex.Unlock()
  328. responses := make(chan *PacketResponse)
  329. messageID := packet.Children[0].Value.(int64)
  330. message := &messagePacket{
  331. Op: MessageRequest,
  332. MessageID: messageID,
  333. Packet: packet,
  334. Context: &messageContext{
  335. id: messageID,
  336. done: make(chan struct{}),
  337. responses: responses,
  338. },
  339. }
  340. l.sendProcessMessage(message)
  341. return message.Context, nil
  342. }
  343. func (l *Conn) finishMessage(msgCtx *messageContext) {
  344. close(msgCtx.done)
  345. if l.IsClosing() {
  346. return
  347. }
  348. l.messageMutex.Lock()
  349. l.outstandingRequests--
  350. if l.isStartingTLS {
  351. l.isStartingTLS = false
  352. }
  353. l.messageMutex.Unlock()
  354. message := &messagePacket{
  355. Op: MessageFinish,
  356. MessageID: msgCtx.id,
  357. }
  358. l.sendProcessMessage(message)
  359. }
  360. func (l *Conn) sendProcessMessage(message *messagePacket) bool {
  361. l.messageMutex.Lock()
  362. defer l.messageMutex.Unlock()
  363. if l.IsClosing() {
  364. return false
  365. }
  366. l.chanMessage <- message
  367. return true
  368. }
  369. func (l *Conn) processMessages() {
  370. defer func() {
  371. if err := recover(); err != nil {
  372. log.Printf("ldap: recovered panic in processMessages: %v", err)
  373. }
  374. for messageID, msgCtx := range l.messageContexts {
  375. // If we are closing due to an error, inform anyone who
  376. // is waiting about the error.
  377. if l.IsClosing() && l.closeErr.Load() != nil {
  378. msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
  379. }
  380. l.Debug.Printf("Closing channel for MessageID %d", messageID)
  381. close(msgCtx.responses)
  382. delete(l.messageContexts, messageID)
  383. }
  384. close(l.chanMessageID)
  385. close(l.chanConfirm)
  386. }()
  387. var messageID int64 = 1
  388. for {
  389. select {
  390. case l.chanMessageID <- messageID:
  391. messageID++
  392. case message := <-l.chanMessage:
  393. switch message.Op {
  394. case MessageQuit:
  395. l.Debug.Printf("Shutting down - quit message received")
  396. return
  397. case MessageRequest:
  398. // Add to message list and write to network
  399. l.Debug.Printf("Sending message %d", message.MessageID)
  400. buf := message.Packet.Bytes()
  401. _, err := l.conn.Write(buf)
  402. if err != nil {
  403. l.Debug.Printf("Error Sending Message: %s", err.Error())
  404. message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
  405. close(message.Context.responses)
  406. break
  407. }
  408. // Only add to messageContexts if we were able to
  409. // successfully write the message.
  410. l.messageContexts[message.MessageID] = message.Context
  411. // Add timeout if defined
  412. requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
  413. if requestTimeout > 0 {
  414. go func() {
  415. defer func() {
  416. if err := recover(); err != nil {
  417. log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
  418. }
  419. }()
  420. time.Sleep(requestTimeout)
  421. timeoutMessage := &messagePacket{
  422. Op: MessageTimeout,
  423. MessageID: message.MessageID,
  424. }
  425. l.sendProcessMessage(timeoutMessage)
  426. }()
  427. }
  428. case MessageResponse:
  429. l.Debug.Printf("Receiving message %d", message.MessageID)
  430. if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
  431. msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
  432. } else {
  433. log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
  434. l.Debug.PrintPacket(message.Packet)
  435. }
  436. case MessageTimeout:
  437. // Handle the timeout by closing the channel
  438. // All reads will return immediately
  439. if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
  440. l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
  441. msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
  442. delete(l.messageContexts, message.MessageID)
  443. close(msgCtx.responses)
  444. }
  445. case MessageFinish:
  446. l.Debug.Printf("Finished message %d", message.MessageID)
  447. if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
  448. delete(l.messageContexts, message.MessageID)
  449. close(msgCtx.responses)
  450. }
  451. }
  452. }
  453. }
  454. }
  455. func (l *Conn) reader() {
  456. cleanstop := false
  457. defer func() {
  458. if err := recover(); err != nil {
  459. log.Printf("ldap: recovered panic in reader: %v", err)
  460. }
  461. if !cleanstop {
  462. l.Close()
  463. }
  464. }()
  465. for {
  466. if cleanstop {
  467. l.Debug.Printf("reader clean stopping (without closing the connection)")
  468. return
  469. }
  470. packet, err := ber.ReadPacket(l.conn)
  471. if err != nil {
  472. // A read error is expected here if we are closing the connection...
  473. if !l.IsClosing() {
  474. l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
  475. l.Debug.Printf("reader error: %s", err)
  476. }
  477. return
  478. }
  479. if err := addLDAPDescriptions(packet); err != nil {
  480. l.Debug.Printf("descriptions error: %s", err)
  481. }
  482. if len(packet.Children) == 0 {
  483. l.Debug.Printf("Received bad ldap packet")
  484. continue
  485. }
  486. l.messageMutex.Lock()
  487. if l.isStartingTLS {
  488. cleanstop = true
  489. }
  490. l.messageMutex.Unlock()
  491. message := &messagePacket{
  492. Op: MessageResponse,
  493. MessageID: packet.Children[0].Value.(int64),
  494. Packet: packet,
  495. }
  496. if !l.sendProcessMessage(message) {
  497. return
  498. }
  499. }
  500. }