| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- package digest
- import (
- "bytes"
- "io"
- "net/http"
- "sync"
- )
- // cchal is a cached challenge and the number of times it's been used.
- type cchal struct {
- c *Challenge
- n int
- }
- // Transport implements http.RoundTripper
- type Transport struct {
- Username string
- Password string
- // Digest computes the digest credentials.
- // If nil, the Digest function is used.
- Digest func(*http.Request, *Challenge, Options) (*Credentials, error)
- // FindChallenge extracts the challenge from the request headers.
- // If nil, the FindChallenge function is used.
- FindChallenge func(http.Header) (*Challenge, error)
- // Transport specifies the mechanism by which individual
- // HTTP requests are made.
- // If nil, DefaultTransport is used.
- Transport http.RoundTripper
- // Jar specifies the cookie jar.
- //
- // The Jar is used to insert relevant cookies into every
- // outbound Request and is updated with the cookie values
- // of every inbound Response. The Jar is consulted for every
- // redirect that the Client follows.
- //
- // If Jar is nil, cookies are only sent if they are explicitly
- // set on the Request.
- Jar http.CookieJar
- // NoReuse prevents the transport from reusing challenges.
- NoReuse bool
- // cache of challenges indexed by host
- cache map[string]*cchal
- cacheMu sync.Mutex
- }
- // save parses the digest challenge from the response
- // and adds it to the cache
- func (t *Transport) save(res *http.Response) error {
- // save cookies
- if t.Jar != nil {
- t.Jar.SetCookies(res.Request.URL, res.Cookies())
- }
- // find and save digest challenge
- find := t.FindChallenge
- if find == nil {
- find = FindChallenge
- }
- chal, err := find(res.Header)
- t.cacheMu.Lock()
- defer t.cacheMu.Unlock()
- if t.cache == nil {
- t.cache = map[string]*cchal{}
- }
- // TODO: if the challenge contains a domain, we should be using that
- // to match against outgoing requests. We're currently ignoring
- // it and just matching the hostname. That being said, none of
- // the major browsers respect the domain either.
- host := res.Request.URL.Hostname()
- if err != nil {
- // if save is being invoked, the existing cached challenge didn't work
- delete(t.cache, host)
- return err
- }
- t.cache[host] = &cchal{c: chal}
- return nil
- }
- // digest creates credentials from the cached challenge
- func (t *Transport) digest(req *http.Request, chal *Challenge, count int) (*Credentials, error) {
- opt := Options{
- Method: req.Method,
- URI: req.URL.RequestURI(),
- GetBody: req.GetBody,
- Count: count,
- Username: t.Username,
- Password: t.Password,
- }
- if t.Digest != nil {
- return t.Digest(req, chal, opt)
- }
- return Digest(chal, opt)
- }
- // challenge returns a cached challenge and count for the provided request
- func (t *Transport) challenge(req *http.Request) (*Challenge, int, bool) {
- t.cacheMu.Lock()
- defer t.cacheMu.Unlock()
- host := req.URL.Hostname()
- cc, ok := t.cache[host]
- if !ok {
- return nil, 0, false
- }
- if t.NoReuse {
- delete(t.cache, host)
- }
- cc.n++
- return cc.c, cc.n, true
- }
- // prepare attempts to find a cached challenge that matches the
- // requested domain, and use it to set the Authorization header
- func (t *Transport) prepare(req *http.Request) error {
- // add cookies
- if t.Jar != nil {
- for _, cookie := range t.Jar.Cookies(req.URL) {
- req.AddCookie(cookie)
- }
- }
- // add auth
- chal, count, ok := t.challenge(req)
- if !ok {
- return nil
- }
- cred, err := t.digest(req, chal, count)
- if err != nil {
- return err
- }
- if cred != nil {
- req.Header.Set("Authorization", cred.String())
- }
- return nil
- }
- // RoundTrip will try to authorize the request using a cached challenge.
- // If that doesn't work and we receive a 401, we'll try again using that challenge.
- func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
- // use the configured transport if there is one
- tr := t.Transport
- if tr == nil {
- tr = http.DefaultTransport
- }
- // don't modify the original request
- clone, err := cloner(req)
- if err != nil {
- return nil, err
- }
- // make a copy of the request
- first, err := clone()
- if err != nil {
- return nil, err
- }
- // prepare the first request using a cached challenge
- if err := t.prepare(first); err != nil {
- return nil, err
- }
- // the first request will either succeed or return a 401
- res, err := tr.RoundTrip(first)
- if err != nil || res.StatusCode != http.StatusUnauthorized {
- return res, err
- }
- // drain and close the first message body
- _, _ = io.Copy(io.Discard, res.Body)
- _ = res.Body.Close()
- // save the challenge for future use
- if err := t.save(res); err != nil {
- if err == ErrNoChallenge {
- return res, nil
- }
- return nil, err
- }
- // make a second copy of the request
- second, err := clone()
- if err != nil {
- return nil, err
- }
- // prepare the second request based on the new challenge
- if err := t.prepare(second); err != nil {
- return nil, err
- }
- return tr.RoundTrip(second)
- }
- // CloseIdleConnections delegates the call to the underlying transport.
- func (t *Transport) CloseIdleConnections() {
- tr := t.Transport
- if tr == nil {
- tr = http.DefaultTransport
- }
- type closeIdler interface {
- CloseIdleConnections()
- }
- if tr, ok := tr.(closeIdler); ok {
- tr.CloseIdleConnections()
- }
- }
// cloner returns a function that produces independent clones of req,
// each with its own fresh copy of the request body.
//
// If req has a GetBody function, it is used to obtain the body of every
// clone. Otherwise the body is read into memory once, and req.Body is
// closed, so it can be replayed; a request without a body yields clones
// whose body is http.NoBody.
//
// An error is returned if the body cannot be read or closed. req.Body
// is closed even when reading it fails, so it is never leaked.
func cloner(req *http.Request) (func() (*http.Request, error), error) {
	getbody := req.GetBody
	// if there's no GetBody function set we have to copy the body
	// into memory to use for future clones
	if getbody == nil {
		if req.Body == nil || req.Body == http.NoBody {
			getbody = func() (io.ReadCloser, error) {
				return http.NoBody, nil
			}
		} else {
			body, err := io.ReadAll(req.Body)
			// always close the body, including after a failed read;
			// the read error takes precedence over the close error
			closeErr := req.Body.Close()
			if err != nil {
				return nil, err
			}
			if closeErr != nil {
				return nil, closeErr
			}
			getbody = func() (io.ReadCloser, error) {
				return io.NopCloser(bytes.NewReader(body)), nil
			}
		}
	}
	return func() (*http.Request, error) {
		clone := req.Clone(req.Context())
		body, err := getbody()
		if err != nil {
			return nil, err
		}
		clone.Body = body
		// lets the clone's consumer obtain the body again if needed
		clone.GetBody = getbody
		return clone, nil
	}, nil
}
|