progress.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package aws
  2. import (
  3. "bytes"
  4. "io"
  5. "os"
  6. "strings"
  7. )
  // ProgressFunc is invoked by readers returned from TeeReader after every
  // Read call that returns data. increment is the number of bytes returned by
  // that call, completed is the running total returned so far, and total is
  // the totalBytes value the caller passed to TeeReader (reported verbatim,
  // never validated or adjusted).
  type ProgressFunc func(increment, completed, total int64)

  // teeReader forwards reads to reader, mirrors the returned bytes to writer
  // and reports progress through tracker. It is not safe for concurrent use:
  // tracker is updated without any synchronization.
  type teeReader struct {
  	reader  io.Reader
  	writer  io.Writer // optional; nil disables mirroring
  	tracker *readerTracker
  }

  // readerTracker accumulates the byte count observed by a teeReader.
  type readerTracker struct {
  	completedBytes int64        // bytes returned to callers so far
  	totalBytes     int64        // expected size supplied by the caller
  	progressFunc   ProgressFunc // optional; nil disables progress reporting
  }

  // Compile-time check that teeReader satisfies the interface TeeReader returns.
  var _ io.ReadCloser = (*teeReader)(nil)

  // TeeReader returns an io.ReadCloser that reads from reader, copies every
  // byte it returns to writer, and calls progressFunc after each Read that
  // returns data.
  //
  // writer and progressFunc may each be nil, in which case that step is
  // skipped. totalBytes is only passed through to progressFunc. There is no
  // internal buffering: the write to writer completes before Read returns,
  // and a write error is reported as the Read error.
  //
  // Close closes reader if it implements io.ReadCloser and is a no-op
  // (returning nil) otherwise.
  func TeeReader(reader io.Reader, writer io.Writer, totalBytes int64, progressFunc ProgressFunc) io.ReadCloser {
  	return &teeReader{
  		reader: reader,
  		writer: writer,
  		tracker: &readerTracker{
  			totalBytes:   totalBytes,
  			progressFunc: progressFunc,
  		},
  	}
  }

  // Read reads from the underlying reader, then records progress and mirrors
  // the bytes to the writer before returning.
  //
  // Per the io.Reader contract, bytes returned together with an error
  // (including a non-EOF error) are valid data, so they are still counted and
  // mirrored; the writer therefore always sees exactly the bytes handed to the
  // caller. If mirroring fails, the writer's byte count and error are
  // returned in place of the read result.
  func (t *teeReader) Read(p []byte) (int, error) {
  	n, err := t.reader.Read(p)
  	if n <= 0 {
  		return n, err
  	}

  	t.tracker.completedBytes += int64(n)
  	if t.tracker.progressFunc != nil {
  		t.tracker.progressFunc(int64(n), t.tracker.completedBytes, t.tracker.totalBytes)
  	}

  	// Mirror the data (e.g. into a CRC hash) before surfacing any read error,
  	// so the mirror never falls behind what the caller consumed.
  	if t.writer != nil {
  		if wn, werr := t.writer.Write(p[:n]); werr != nil {
  			return wn, werr
  		}
  	}
  	return n, err
  }

  // Close closes the underlying reader when it implements io.ReadCloser and
  // returns nil otherwise.
  func (t *teeReader) Close() error {
  	if rc, ok := t.reader.(io.ReadCloser); ok {
  		return rc.Close()
  	}
  	return nil
  }
  // GetReaderLen returns the number of bytes still available to read from
  // reader, or 0 when that cannot be determined.
  //
  // Recognized types:
  //   - *bytes.Buffer, *bytes.Reader, *strings.Reader: their unread length.
  //   - *os.File: the size reported by Stat minus the current offset, so a
  //     file that has already been partially read or seeked is not
  //     over-counted. If the offset cannot be queried (e.g. a non-seekable
  //     file) the Stat size is returned unchanged.
  //
  // Any other type, a nil reader, or a failing Stat yields 0.
  func GetReaderLen(reader io.Reader) int64 {
  	switch v := reader.(type) {
  	case *bytes.Buffer:
  		return int64(v.Len())
  	case *bytes.Reader:
  		return int64(v.Len())
  	case *strings.Reader:
  		return int64(v.Len())
  	case *os.File:
  		info, err := v.Stat()
  		if err != nil {
  			return 0
  		}
  		size := info.Size()
  		// Seek(0, io.SeekCurrent) reports the current offset without moving it.
  		offset, err := v.Seek(0, io.SeekCurrent)
  		if err != nil {
  			return size
  		}
  		if offset >= size {
  			return 0
  		}
  		return size - offset
  	default:
  		return 0
  	}
  }