downloader.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. package s3
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "github.com/ks3sdklib/aws-sdk-go/aws"
  7. "github.com/ks3sdklib/aws-sdk-go/internal/crc"
  8. "hash"
  9. "io"
  10. "os"
  11. "path/filepath"
  12. "sort"
  13. "strconv"
  14. "sync"
  15. "sync/atomic"
  16. "time"
  17. )
// DownloadFileInput describes a download of an object (or a byte range of
// it) to a local file. The download is split into parts that are fetched
// concurrently and can optionally be resumed from a checkpoint file.
type DownloadFileInput struct {
	// The name of the bucket.
	Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
	// Object key of the object.
	Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
	// Local path the object is written to. When empty, the object key is used
	// as the path. The value is replaced by its cleaned, absolute form before
	// the download starts.
	DownloadFile *string `type:"string" locationName:"DownloadFile" required:"true"`
	// The size of each part. When nil, DefaultPartSize is used; values outside
	// [MinPartSize, MaxPartSize] are clamped to the nearest bound.
	PartSize *int64 `type:"integer" locationName:"PartSize"`
	// The number of concurrent tasks used to download the file. Values less
	// than or equal to zero (or nil) are replaced by DefaultTaskNum.
	TaskNum *int64 `type:"integer" locationName:"TaskNum"`
	// Whether to enable checkpoint.
	EnableCheckpoint *bool `type:"boolean" locationName:"EnableCheckpoint"`
	// The directory to store the checkpoint file.
	CheckpointDir *string `type:"string" locationName:"CheckpointDir"`
	// The checkpoint file path. When empty and checkpointing is enabled, a
	// path is generated from the request.
	CheckpointFile *string `type:"string" locationName:"CheckpointFile"`
	// Range selects the bytes of the object to download as a two-element list
	// [start, end]; both offsets are zero-based and inclusive.
	//   - start >= 0 and end >= 0: download bytes start through end (end is
	//     capped at the last byte of the object).
	//   - start >= 0 and end < 0: download from start to the end of the object.
	//   - start < 0 and end >= 0: download the last `end` bytes of the object.
	//   - Anything else (both values negative, start > end with end >= 0,
	//     start at or beyond the object size, or a list whose length is not 2)
	//     is treated as invalid: a warning is logged and the entire object is
	//     downloaded.
	// A nil Range downloads the entire object.
	// Examples:
	//   range=[0, 99]    the first 100 bytes
	//   range=[100, 199] the 101st through the 200th byte
	//   range=[100, -1]  from the 101st byte to the end of the object
	//   range=[-1, 100]  the last 100 bytes
	Range []int64 `locationName:"Range" type:"list"`
	// Sets the Content-Type header of the response.
	ResponseContentType *string `location:"querystring" locationName:"response-content-type" type:"string"`
	// Sets the Content-Language header of the response.
	ResponseContentLanguage *string `location:"querystring" locationName:"response-content-language" type:"string"`
	// Sets the Expires header of the response.
	ResponseExpires *time.Time `location:"querystring" locationName:"response-expires" type:"timestamp" timestampFormat:"iso8601"`
	// Sets the Cache-Control header of the response.
	ResponseCacheControl *string `location:"querystring" locationName:"response-cache-control" type:"string"`
	// Sets the Content-Disposition header of the response.
	ResponseContentDisposition *string `location:"querystring" locationName:"response-content-disposition" type:"string"`
	// Sets the Content-Encoding header of the response.
	ResponseContentEncoding *string `location:"querystring" locationName:"response-content-encoding" type:"string"`
	// Return the object only if it has been modified since the specified time,
	// otherwise return a 304 (not modified).
	IfModifiedSince *time.Time `location:"header" locationName:"If-Modified-Since" type:"timestamp" timestampFormat:"rfc822"`
	// Return the object only if it has not been modified since the specified time,
	// otherwise return a 412 (precondition failed).
	IfUnmodifiedSince *time.Time `location:"header" locationName:"If-Unmodified-Since" type:"timestamp" timestampFormat:"rfc822"`
	// Return the object only if its entity tag (ETag) is the same as the one specified,
	// otherwise return a 412 (precondition failed).
	IfMatch *string `location:"header" locationName:"If-Match" type:"string"`
	// Return the object only if its entity tag (ETag) is different from the one
	// specified, otherwise return a 304 (not modified).
	IfNoneMatch *string `location:"header" locationName:"If-None-Match" type:"string"`
	// Specify the encoding type of the client.
	// If you want to compress and transmit the returned content using gzip,
	// you need to add a request header: Accept-Encoding: gzip.
	// KS3 will determine whether to return gzip compressed data based on the
	// Content-Type and Object size (not less than 1 KB) of the object.
	// Value: gzip, br, deflate
	AcceptEncoding *string `location:"header" locationName:"Accept-Encoding" type:"string"`
	// Specifies the algorithm to use to when encrypting the object, eg: AES256.
	SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
	// Specifies the customer-provided encryption key for KS3 to use in encrypting data.
	SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
	// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
	SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
	// Progress callback function. It is invoked once per part with the size of
	// that part, the number of bytes completed so far, and the total number of
	// bytes to download.
	ProgressFn aws.ProgressFunc `location:"function"`
}
// DownloadFileOutput is the result of a successful DownloadFile call.
type DownloadFileOutput struct {
	// Bucket is the bucket name taken from the request.
	Bucket *string
	// Key is the object key taken from the request.
	Key *string
	// ETag is the ETag entry of the object metadata returned by the HEAD request.
	ETag *string
	// ChecksumCRC64ECMA is the CRC64 checksum entry of the object metadata
	// returned by the HEAD request.
	ChecksumCRC64ECMA *string
	// ObjectMeta is the full metadata map returned by the HEAD request.
	ObjectMeta map[string]*string
}
  95. func (c *S3) DownloadFile(request *DownloadFileInput) (*DownloadFileOutput, error) {
  96. return c.DownloadFileWithContext(context.Background(), request)
  97. }
  98. func (c *S3) DownloadFileWithContext(ctx context.Context, request *DownloadFileInput) (*DownloadFileOutput, error) {
  99. return newDownloader(c, ctx, request).downloadFile()
  100. }
// Downloader holds the state of a single DownloadFile operation.
type Downloader struct {
	// client issues the HEAD and ranged GET requests.
	client *S3
	// context is passed to every request made by the download.
	context context.Context
	// downloadFileRequest is the caller's request; validation fills in
	// defaults on it (download path, part size, task number).
	downloadFileRequest *DownloadFileInput
	// downloadCheckpoint records which parts have been downloaded.
	downloadCheckpoint *DownloadCheckpoint
	// CompletedSize is the number of bytes reported through the progress
	// callback so far; it is updated atomically.
	CompletedSize int64
	// downloadFileSize is the number of bytes covered by the requested range.
	downloadFileSize int64
	// downloadFileMeta is the object metadata returned by the HEAD request.
	downloadFileMeta map[string]*string
	// mu guards error and updates to the checkpoint's part list.
	mu sync.Mutex
	// error is the first error reported by any download task.
	error error
}
  112. func newDownloader(s3 *S3, ctx context.Context, request *DownloadFileInput) *Downloader {
  113. return &Downloader{
  114. client: s3,
  115. context: ctx,
  116. downloadFileRequest: request,
  117. }
  118. }
  119. func (d *Downloader) downloadFile() (*DownloadFileOutput, error) {
  120. err := d.validate()
  121. if err != nil {
  122. return nil, err
  123. }
  124. d.downloadFileMeta, err = d.headObject()
  125. if err != nil {
  126. return nil, err
  127. }
  128. dcp, err := newDownloadCheckpoint(d)
  129. if err != nil {
  130. return nil, err
  131. }
  132. d.downloadCheckpoint = dcp
  133. if aws.ToBoolean(d.downloadFileRequest.EnableCheckpoint) {
  134. cpFilePath := aws.ToString(d.downloadFileRequest.CheckpointFile)
  135. if cpFilePath == "" {
  136. cpFilePath, err = generateDownloadCpFilePath(d.downloadFileRequest)
  137. if err != nil {
  138. return nil, err
  139. }
  140. }
  141. dcp.CpFilePath = cpFilePath
  142. err = dcp.load()
  143. if err != nil {
  144. return nil, err
  145. }
  146. if !FileExists(dcp.DownloadFilePath + TempFileSuffix) {
  147. dcp.PartETagList = make([]*CompletedPart, 0)
  148. dcp.remove()
  149. }
  150. }
  151. err = d.createDownloadDir(dcp.DownloadFilePath + TempFileSuffix)
  152. if err != nil {
  153. return nil, err
  154. }
  155. objectRange := d.getObjectRange()
  156. d.downloadFileSize = objectRange[1] - objectRange[0] + 1
  157. partSize := aws.ToLong(d.downloadFileRequest.PartSize)
  158. totalPartNum := (d.downloadFileSize-1)/partSize + 1
  159. tasks := make(chan DownloadPartTask, totalPartNum)
  160. var i int64
  161. for i = 0; i < totalPartNum; i++ {
  162. partNum := i + 1
  163. start := objectRange[0] + i*partSize
  164. end := Min(start+partSize-1, objectRange[1])
  165. actualPartSize := end - start + 1
  166. if d.getPartETag(partNum) != nil {
  167. d.publishProgress(actualPartSize)
  168. } else {
  169. downloadPartTask := DownloadPartTask{
  170. partNumber: partNum,
  171. start: start,
  172. end: end,
  173. actualPartSize: actualPartSize,
  174. }
  175. tasks <- downloadPartTask
  176. }
  177. }
  178. close(tasks)
  179. var wg sync.WaitGroup
  180. for i = 0; i < aws.ToLong(d.downloadFileRequest.TaskNum); i++ {
  181. wg.Add(1)
  182. go d.runTask(tasks, &wg)
  183. }
  184. wg.Wait()
  185. if d.error != nil {
  186. return nil, d.error
  187. }
  188. if d.downloadFileRequest.Range == nil && d.client.Config.CrcCheckEnabled {
  189. clientCrc64 := d.getCrc64Ecma(dcp.PartETagList)
  190. serverCrc64, _ := strconv.ParseUint(aws.ToString(d.downloadFileMeta[HTTPHeaderAmzChecksumCrc64ecma]), 10, 64)
  191. d.client.Config.LogDebug("check file crc64, client crc64:%d, server crc64:%d", clientCrc64, serverCrc64)
  192. if serverCrc64 != 0 && clientCrc64 != serverCrc64 {
  193. return nil, errors.New(fmt.Sprintf("crc64 check failed, client crc64:%d, server crc64:%d", clientCrc64, serverCrc64))
  194. }
  195. }
  196. err = d.complete()
  197. if err != nil {
  198. return nil, err
  199. }
  200. return d.getDownloadFileOutput(), nil
  201. }
  202. func (d *Downloader) validate() error {
  203. request := d.downloadFileRequest
  204. if request == nil {
  205. return errors.New("download file request is required")
  206. }
  207. if aws.ToString(request.Bucket) == "" {
  208. return errors.New("bucket is required")
  209. }
  210. if aws.ToString(request.Key) == "" {
  211. return errors.New("key is required")
  212. }
  213. err := d.normalizeDownloadPath()
  214. if err != nil {
  215. return err
  216. }
  217. if request.PartSize == nil {
  218. request.PartSize = aws.Long(DefaultPartSize)
  219. } else if aws.ToLong(request.PartSize) < MinPartSize {
  220. request.PartSize = aws.Long(MinPartSize)
  221. } else if aws.ToLong(request.PartSize) > MaxPartSize {
  222. request.PartSize = aws.Long(MaxPartSize)
  223. }
  224. if aws.ToLong(request.TaskNum) <= 0 {
  225. request.TaskNum = aws.Long(DefaultTaskNum)
  226. }
  227. return nil
  228. }
  229. func (d *Downloader) getDownloadFileOutput() *DownloadFileOutput {
  230. return &DownloadFileOutput{
  231. Bucket: d.downloadFileRequest.Bucket,
  232. Key: d.downloadFileRequest.Key,
  233. ETag: d.downloadFileMeta[HTTPHeaderEtag],
  234. ChecksumCRC64ECMA: d.downloadFileMeta[HTTPHeaderAmzChecksumCrc64ecma],
  235. ObjectMeta: d.downloadFileMeta,
  236. }
  237. }
  238. func (d *Downloader) getActualPartSize(fileSize int64, partSize int64, partNum int64) int64 {
  239. offset := (partNum - 1) * partSize
  240. actualPartSize := partSize
  241. if offset+partSize >= fileSize {
  242. actualPartSize = fileSize - offset
  243. }
  244. return actualPartSize
  245. }
  246. func (d *Downloader) getPartETag(partNumber int64) *CompletedPart {
  247. for _, partETag := range d.downloadCheckpoint.PartETagList {
  248. if *partETag.PartNumber == partNumber {
  249. return partETag
  250. }
  251. }
  252. return nil
  253. }
// DownloadPartTask describes one part of the object to be downloaded by a
// worker.
type DownloadPartTask struct {
	// partNumber is the 1-based index of the part.
	partNumber int64
	// actualPartSize is the number of bytes in the part (end - start + 1).
	actualPartSize int64
	// start is the offset within the object of the first byte of the part.
	start int64
	// end is the offset within the object of the last byte of the part
	// (inclusive).
	end int64
}
  260. func (d *Downloader) runTask(tasks <-chan DownloadPartTask, wg *sync.WaitGroup) {
  261. defer wg.Done()
  262. for task := range tasks {
  263. if d.error != nil {
  264. return
  265. }
  266. partETag, err := d.downloadPart(task)
  267. if err != nil {
  268. d.setError(err)
  269. return
  270. }
  271. d.updatePart(partETag)
  272. }
  273. }
  274. func (d *Downloader) downloadPart(task DownloadPartTask) (CompletedPart, error) {
  275. request := d.downloadFileRequest
  276. dcp := d.downloadCheckpoint
  277. tempFilePath := dcp.DownloadFilePath + TempFileSuffix
  278. var completedPart CompletedPart
  279. resp, err := d.client.GetObjectWithContext(d.context, &GetObjectInput{
  280. Bucket: aws.String(dcp.BucketName),
  281. Key: aws.String(dcp.ObjectKey),
  282. Range: aws.String(fmt.Sprintf("bytes=%d-%d", task.start, task.end)),
  283. ResponseContentType: request.ResponseContentType,
  284. ResponseContentLanguage: request.ResponseContentLanguage,
  285. ResponseExpires: request.ResponseExpires,
  286. ResponseCacheControl: request.ResponseCacheControl,
  287. ResponseContentDisposition: request.ResponseContentDisposition,
  288. ResponseContentEncoding: request.ResponseContentEncoding,
  289. IfModifiedSince: request.IfModifiedSince,
  290. IfUnmodifiedSince: request.IfUnmodifiedSince,
  291. IfMatch: request.IfMatch,
  292. IfNoneMatch: request.IfNoneMatch,
  293. AcceptEncoding: request.AcceptEncoding,
  294. SSECustomerAlgorithm: request.SSECustomerAlgorithm,
  295. SSECustomerKey: request.SSECustomerKey,
  296. SSECustomerKeyMD5: request.SSECustomerKeyMD5,
  297. })
  298. if err != nil {
  299. return completedPart, err
  300. }
  301. defer resp.Body.Close()
  302. var crc64 hash.Hash64
  303. crc64 = crc.NewCRC(crc.CrcTable(), 0)
  304. resp.Body = aws.TeeReader(resp.Body, crc64, task.actualPartSize, nil)
  305. fd, err := os.OpenFile(tempFilePath, os.O_WRONLY|os.O_CREATE, FilePermMode)
  306. if err != nil {
  307. return completedPart, err
  308. }
  309. defer fd.Close()
  310. _, err = fd.Seek((task.partNumber-1)*dcp.PartSize, io.SeekStart)
  311. if err != nil {
  312. return completedPart, err
  313. }
  314. _, err = io.Copy(fd, resp.Body)
  315. if err != nil {
  316. return completedPart, err
  317. }
  318. completedPart.PartNumber = aws.Long(task.partNumber)
  319. completedPart.ChecksumCRC64ECMA = aws.String(strconv.FormatUint(crc64.Sum64(), 10))
  320. d.publishProgress(task.actualPartSize)
  321. return completedPart, nil
  322. }
  323. func (d *Downloader) updatePart(partETag CompletedPart) {
  324. d.mu.Lock()
  325. defer d.mu.Unlock()
  326. d.downloadCheckpoint.PartETagList = append(d.downloadCheckpoint.PartETagList, &partETag)
  327. d.downloadCheckpoint.dump()
  328. }
  329. func (d *Downloader) setError(err error) {
  330. d.mu.Lock()
  331. defer d.mu.Unlock()
  332. if d.error == nil {
  333. d.error = err
  334. }
  335. }
  336. func (d *Downloader) publishProgress(actualPartSize int64) {
  337. if d.downloadFileRequest.ProgressFn != nil {
  338. atomic.AddInt64(&d.CompletedSize, actualPartSize)
  339. d.downloadFileRequest.ProgressFn(actualPartSize, d.CompletedSize, d.downloadFileSize)
  340. }
  341. }
  342. func (d *Downloader) getCrc64Ecma(parts []*CompletedPart) uint64 {
  343. if parts == nil || len(parts) == 0 {
  344. return 0
  345. }
  346. sort.Sort(CompletedParts(d.downloadCheckpoint.PartETagList))
  347. crcTemp, _ := strconv.ParseUint(*parts[0].ChecksumCRC64ECMA, 10, 64)
  348. for i := 1; i < len(parts); i++ {
  349. crc2, _ := strconv.ParseUint(*parts[i].ChecksumCRC64ECMA, 10, 64)
  350. partSize := d.getActualPartSize(d.downloadFileSize, aws.ToLong(d.downloadFileRequest.PartSize), *parts[i].PartNumber)
  351. crcTemp = crc.CRC64Combine(crcTemp, crc2, (uint64)(partSize))
  352. }
  353. return crcTemp
  354. }
  355. func (d *Downloader) complete() error {
  356. fileName := aws.ToString(d.downloadFileRequest.DownloadFile)
  357. tempFileName := fileName + TempFileSuffix
  358. err := os.Rename(tempFileName, fileName)
  359. if err != nil {
  360. return err
  361. }
  362. d.downloadCheckpoint.remove()
  363. return nil
  364. }
  365. func (d *Downloader) headObject() (map[string]*string, error) {
  366. request := d.downloadFileRequest
  367. resp, err := d.client.HeadObjectWithContext(d.context, &HeadObjectInput{
  368. Bucket: request.Bucket,
  369. Key: request.Key,
  370. IfModifiedSince: request.IfModifiedSince,
  371. IfUnmodifiedSince: request.IfUnmodifiedSince,
  372. IfMatch: request.IfMatch,
  373. IfNoneMatch: request.IfNoneMatch,
  374. SSECustomerAlgorithm: request.SSECustomerAlgorithm,
  375. SSECustomerKey: request.SSECustomerKey,
  376. SSECustomerKeyMD5: request.SSECustomerKeyMD5,
  377. })
  378. if err != nil {
  379. return nil, err
  380. }
  381. return resp.Metadata, err
  382. }
  383. func (d *Downloader) createDownloadDir(filePath string) error {
  384. dir := filepath.Dir(filePath)
  385. if !DirExists(dir) {
  386. err := os.MkdirAll(dir, DirPermMode)
  387. if err != nil {
  388. return err
  389. }
  390. }
  391. return nil
  392. }
  393. func (d *Downloader) getObjectRange() []int64 {
  394. objectRange := d.downloadFileRequest.Range
  395. objectSize := d.downloadCheckpoint.ObjectSize
  396. if objectRange == nil {
  397. return []int64{0, objectSize - 1}
  398. }
  399. if !d.isValidRange(objectRange, objectSize) {
  400. d.client.Config.LogWarn("Invalid range value: %v, ignore it and request for entire object", objectRange)
  401. return []int64{0, objectSize - 1}
  402. }
  403. objectStart := objectRange[0]
  404. objectEnd := objectRange[1]
  405. if objectStart < 0 {
  406. return []int64{objectSize - objectEnd, objectSize - 1}
  407. }
  408. if objectEnd < 0 {
  409. return []int64{objectStart, objectSize - 1}
  410. }
  411. return []int64{objectStart, Min(objectEnd, objectSize-1)}
  412. }
  413. func (d *Downloader) isValidRange(objectRange []int64, objectSize int64) bool {
  414. if len(objectRange) != 2 {
  415. return false
  416. }
  417. objectStart := objectRange[0]
  418. objectEnd := objectRange[1]
  419. if objectStart < 0 && objectEnd < 0 || objectEnd >= 0 && objectStart > objectEnd {
  420. return false
  421. }
  422. return objectStart < objectSize
  423. }
  424. func (d *Downloader) normalizeDownloadPath() error {
  425. downloadPath := aws.ToString(d.downloadFileRequest.DownloadFile)
  426. if downloadPath == "" {
  427. downloadPath = aws.ToString(d.downloadFileRequest.Key)
  428. }
  429. // 规范化路径
  430. normalizedPath := filepath.Clean(downloadPath)
  431. // 获取绝对路径
  432. absPath, err := filepath.Abs(normalizedPath)
  433. if err != nil {
  434. return err
  435. }
  436. d.downloadFileRequest.DownloadFile = aws.String(absPath)
  437. return nil
  438. }