transport.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. package digest
  import (
	"bytes"
	"errors"
	"io"
	"net/http"
	"sync"
)
  8. // cchal is a cached challenge and the number of times it's been used.
  9. type cchal struct {
  10. c *Challenge
  11. n int
  12. }
  13. // Transport implements http.RoundTripper
  14. type Transport struct {
  15. Username string
  16. Password string
  17. // Digest computes the digest credentials.
  18. // If nil, the Digest function is used.
  19. Digest func(*http.Request, *Challenge, Options) (*Credentials, error)
  20. // FindChallenge extracts the challenge from the request headers.
  21. // If nil, the FindChallenge function is used.
  22. FindChallenge func(http.Header) (*Challenge, error)
  23. // Transport specifies the mechanism by which individual
  24. // HTTP requests are made.
  25. // If nil, DefaultTransport is used.
  26. Transport http.RoundTripper
  27. // Jar specifies the cookie jar.
  28. //
  29. // The Jar is used to insert relevant cookies into every
  30. // outbound Request and is updated with the cookie values
  31. // of every inbound Response. The Jar is consulted for every
  32. // redirect that the Client follows.
  33. //
  34. // If Jar is nil, cookies are only sent if they are explicitly
  35. // set on the Request.
  36. Jar http.CookieJar
  37. // NoReuse prevents the transport from reusing challenges.
  38. NoReuse bool
  39. // cache of challenges indexed by host
  40. cache map[string]*cchal
  41. cacheMu sync.Mutex
  42. }
  43. // save parses the digest challenge from the response
  44. // and adds it to the cache
  45. func (t *Transport) save(res *http.Response) error {
  46. // save cookies
  47. if t.Jar != nil {
  48. t.Jar.SetCookies(res.Request.URL, res.Cookies())
  49. }
  50. // find and save digest challenge
  51. find := t.FindChallenge
  52. if find == nil {
  53. find = FindChallenge
  54. }
  55. chal, err := find(res.Header)
  56. t.cacheMu.Lock()
  57. defer t.cacheMu.Unlock()
  58. if t.cache == nil {
  59. t.cache = map[string]*cchal{}
  60. }
  61. // TODO: if the challenge contains a domain, we should be using that
  62. // to match against outgoing requests. We're currently ignoring
  63. // it and just matching the hostname. That being said, none of
  64. // the major browsers respect the domain either.
  65. host := res.Request.URL.Hostname()
  66. if err != nil {
  67. // if save is being invoked, the existing cached challenge didn't work
  68. delete(t.cache, host)
  69. return err
  70. }
  71. t.cache[host] = &cchal{c: chal}
  72. return nil
  73. }
  74. // digest creates credentials from the cached challenge
  75. func (t *Transport) digest(req *http.Request, chal *Challenge, count int) (*Credentials, error) {
  76. opt := Options{
  77. Method: req.Method,
  78. URI: req.URL.RequestURI(),
  79. GetBody: req.GetBody,
  80. Count: count,
  81. Username: t.Username,
  82. Password: t.Password,
  83. }
  84. if t.Digest != nil {
  85. return t.Digest(req, chal, opt)
  86. }
  87. return Digest(chal, opt)
  88. }
  89. // challenge returns a cached challenge and count for the provided request
  90. func (t *Transport) challenge(req *http.Request) (*Challenge, int, bool) {
  91. t.cacheMu.Lock()
  92. defer t.cacheMu.Unlock()
  93. host := req.URL.Hostname()
  94. cc, ok := t.cache[host]
  95. if !ok {
  96. return nil, 0, false
  97. }
  98. if t.NoReuse {
  99. delete(t.cache, host)
  100. }
  101. cc.n++
  102. return cc.c, cc.n, true
  103. }
  104. // prepare attempts to find a cached challenge that matches the
  105. // requested domain, and use it to set the Authorization header
  106. func (t *Transport) prepare(req *http.Request) error {
  107. // add cookies
  108. if t.Jar != nil {
  109. for _, cookie := range t.Jar.Cookies(req.URL) {
  110. req.AddCookie(cookie)
  111. }
  112. }
  113. // add auth
  114. chal, count, ok := t.challenge(req)
  115. if !ok {
  116. return nil
  117. }
  118. cred, err := t.digest(req, chal, count)
  119. if err != nil {
  120. return err
  121. }
  122. if cred != nil {
  123. req.Header.Set("Authorization", cred.String())
  124. }
  125. return nil
  126. }
  127. // RoundTrip will try to authorize the request using a cached challenge.
  128. // If that doesn't work and we receive a 401, we'll try again using that challenge.
  129. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  130. // use the configured transport if there is one
  131. tr := t.Transport
  132. if tr == nil {
  133. tr = http.DefaultTransport
  134. }
  135. // don't modify the original request
  136. clone, err := cloner(req)
  137. if err != nil {
  138. return nil, err
  139. }
  140. // make a copy of the request
  141. first, err := clone()
  142. if err != nil {
  143. return nil, err
  144. }
  145. // prepare the first request using a cached challenge
  146. if err := t.prepare(first); err != nil {
  147. return nil, err
  148. }
  149. // the first request will either succeed or return a 401
  150. res, err := tr.RoundTrip(first)
  151. if err != nil || res.StatusCode != http.StatusUnauthorized {
  152. return res, err
  153. }
  154. // drain and close the first message body
  155. _, _ = io.Copy(io.Discard, res.Body)
  156. _ = res.Body.Close()
  157. // save the challenge for future use
  158. if err := t.save(res); err != nil {
  159. if err == ErrNoChallenge {
  160. return res, nil
  161. }
  162. return nil, err
  163. }
  164. // make a second copy of the request
  165. second, err := clone()
  166. if err != nil {
  167. return nil, err
  168. }
  169. // prepare the second request based on the new challenge
  170. if err := t.prepare(second); err != nil {
  171. return nil, err
  172. }
  173. return tr.RoundTrip(second)
  174. }
  175. // CloseIdleConnections delegates the call to the underlying transport.
  176. func (t *Transport) CloseIdleConnections() {
  177. tr := t.Transport
  178. if tr == nil {
  179. tr = http.DefaultTransport
  180. }
  181. type closeIdler interface {
  182. CloseIdleConnections()
  183. }
  184. if tr, ok := tr.(closeIdler); ok {
  185. tr.CloseIdleConnections()
  186. }
  187. }
  // cloner returns a function that makes independent clones of req. Each
// clone has its own unread body, so the same request can be sent more than
// once, as RoundTrip does when it retries after a challenge.
//
// If req.GetBody is set, it supplies each clone's body. Otherwise a real body
// is read fully into memory once and replayed from there. A nil or
// http.NoBody body yields clones whose body is http.NoBody. Each clone has
// GetBody set so it can be replayed further.
//
// Clones never read from req.Body, so any real body is closed here, even when
// reading it fails. An http.RoundTripper is required to close the request
// body.
func cloner(req *http.Request) (func() (*http.Request, error), error) {
	getbody := req.GetBody
	switch {
	case getbody != nil:
		// The original body is superseded by GetBody. Closing it is best
		// effort, since a failure here cannot affect the clones.
		if req.Body != nil {
			_ = req.Body.Close()
		}
	case req.Body == nil || req.Body == http.NoBody:
		getbody = func() (io.ReadCloser, error) {
			return http.NoBody, nil
		}
	default:
		// There is no way to re-obtain the body, so buffer it for replay.
		// Close it whether or not reading succeeded.
		body, readErr := io.ReadAll(req.Body)
		closeErr := req.Body.Close()
		if readErr != nil {
			return nil, readErr
		}
		if closeErr != nil {
			return nil, closeErr
		}
		getbody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(body)), nil
		}
	}
	return func() (*http.Request, error) {
		clone := req.Clone(req.Context())
		body, err := getbody()
		if err != nil {
			return nil, err
		}
		clone.Body = body
		clone.GetBody = getbody
		return clone, nil
	}, nil
}