streamutils.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package streamutils
  15. import (
  16. "crypto/md5"
  17. "fmt"
  18. "hash"
  19. "io"
  20. "github.com/ulikunitz/xz"
  21. "yunion.io/x/pkg/errors"
  22. )
  // SStreamProperty summarizes a stream after it has been piped by
// StreamPipe / StreamPipe2.
type SStreamProperty struct {
	// CheckSum is the lowercase hex MD5 digest of the piped bytes; it is
	// left empty when checksum calculation was not requested.
	CheckSum string
	// Size is the total number of bytes written to the destination.
	Size int64
}
  // sXZReadAheadReader wraps an upstream reader whose leading bytes were
// already consumed into header, and replays those bytes before reading
// anything further from upstream.
type sXZReadAheadReader struct {
	// offset counts how many bytes have been handed out through Read so far,
	// header bytes included.
	offset int64
	// header holds the bytes read ahead from upstream.
	header []byte
	// hdrEof records that upstream was exhausted while header was filled, so
	// Read must report io.EOF once header has been drained.
	hdrEof bool
	// upstream is the source that is read once header has been replayed.
	upstream io.Reader
}
  33. func newXZReadAheadReader(stream io.Reader) (*sXZReadAheadReader, error) {
  34. xzHdr := make([]byte, xz.HeaderLen)
  35. n, err := stream.Read(xzHdr)
  36. hdrEof := false
  37. if err != nil {
  38. if errors.Cause(err) == io.EOF {
  39. // delay the EOF
  40. hdrEof = true
  41. xzHdr = xzHdr[:n]
  42. } else {
  43. return nil, errors.Wrap(err, "Read XZ header")
  44. }
  45. } else if n != len(xzHdr) {
  46. hdrEof = true
  47. xzHdr = xzHdr[:n]
  48. }
  49. return &sXZReadAheadReader{
  50. offset: 0,
  51. header: xzHdr,
  52. hdrEof: hdrEof,
  53. upstream: stream,
  54. }, nil
  55. }
  56. func (s *sXZReadAheadReader) IsXz() bool {
  57. return xz.ValidHeader(s.header)
  58. }
  59. func (s *sXZReadAheadReader) Read(buf []byte) (int, error) {
  60. bufOffset := 0
  61. if s.offset < int64(len(s.header)) {
  62. // read from header
  63. rdSize := len(s.header) - int(s.offset)
  64. if rdSize > len(buf) {
  65. rdSize = len(buf)
  66. }
  67. n := copy(buf, s.header[s.offset:s.offset+int64(rdSize)])
  68. s.offset += int64(n)
  69. bufOffset = n
  70. }
  71. // read buffer is full
  72. if bufOffset >= len(buf) {
  73. return bufOffset, nil
  74. }
  75. if s.offset >= int64(len(s.header)) && s.hdrEof {
  76. return bufOffset, io.EOF
  77. }
  78. n, err := s.upstream.Read(buf[bufOffset:])
  79. s.offset += int64(n)
  80. return n + bufOffset, err
  81. }
  82. func StreamPipe(upstream io.Reader, writer io.Writer, CalChecksum bool, callback func(savedTotal int64)) (*SStreamProperty, error) {
  83. return StreamPipe2(upstream, writer, CalChecksum, func(savedTotal int64, savedOnce int64) {
  84. if callback != nil {
  85. callback(savedTotal)
  86. }
  87. })
  88. }
  89. func StreamPipe2(upstream io.Reader, writer io.Writer, CalChecksum bool, callback func(savedTotal int64, savedOnce int64)) (*SStreamProperty, error) {
  90. sp := SStreamProperty{}
  91. var md5sum hash.Hash
  92. if CalChecksum {
  93. md5sum = md5.New()
  94. }
  95. aheadReader, err := newXZReadAheadReader(upstream)
  96. if err != nil {
  97. return nil, errors.Wrap(err, "ReadAheadReader")
  98. }
  99. var reader io.Reader
  100. if aheadReader.IsXz() {
  101. xzReader, err := xz.NewReader(aheadReader)
  102. if err != nil {
  103. return nil, errors.Wrap(err, "xz.NewReader")
  104. }
  105. reader = xzReader
  106. } else {
  107. reader = aheadReader
  108. }
  109. buf := make([]byte, 4096)
  110. for {
  111. n, err := reader.Read(buf)
  112. if n > 0 {
  113. sp.Size += int64(n)
  114. if callback != nil {
  115. callback(sp.Size, int64(n))
  116. }
  117. if CalChecksum {
  118. md5sum.Write(buf[:n])
  119. }
  120. offset := 0
  121. for offset < n {
  122. m, err := writer.Write(buf[offset:n])
  123. if err != nil {
  124. return nil, err
  125. }
  126. offset += m
  127. }
  128. }
  129. if err != nil {
  130. if err == io.EOF {
  131. break
  132. }
  133. return nil, err
  134. }
  135. }
  136. if CalChecksum {
  137. sp.CheckSum = fmt.Sprintf("%x", md5sum.Sum(nil))
  138. }
  139. return &sp, nil
  140. }