| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542 |
- package s3
- import (
- "context"
- "errors"
- "fmt"
- "github.com/ks3sdklib/aws-sdk-go/aws"
- "github.com/ks3sdklib/aws-sdk-go/internal/crc"
- "hash"
- "io"
- "os"
- "path/filepath"
- "sort"
- "strconv"
- "sync"
- "sync/atomic"
- "time"
- )
// DownloadFileInput holds the parameters for S3.DownloadFile and
// S3.DownloadFileWithContext, which download an object into a local file by
// issuing ranged GetObject requests from several worker goroutines.
type DownloadFileInput struct {
	// The name of the bucket. Must be non-empty.
	Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
	// Object key of the object. Must be non-empty.
	Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
	// The path of the file to be downloaded.
	// If empty, the object key is used as the path. The path is cleaned and made
	// absolute before the download starts.
	DownloadFile *string `type:"string" locationName:"DownloadFile" required:"true"`
	// The size of each part, in bytes. Defaults to DefaultPartSize when nil and is
	// clamped to [MinPartSize, MaxPartSize].
	PartSize *int64 `type:"integer" locationName:"PartSize"`
	// The number of worker goroutines used to download parts.
	// Defaults to DefaultTaskNum when nil or <= 0.
	TaskNum *int64 `type:"integer" locationName:"TaskNum"`
	// Whether to enable checkpoint. When true, completed parts are recorded in a
	// checkpoint file, and a later run skips parts already recorded there (as long
	// as the temporary download file still exists).
	EnableCheckpoint *bool `type:"boolean" locationName:"EnableCheckpoint"`
	// The directory to store the checkpoint file.
	// NOTE(review): presumably consulted by generateDownloadCpFilePath when
	// CheckpointFile is empty — confirm.
	CheckpointDir *string `type:"string" locationName:"CheckpointDir"`
	// The checkpoint file path. If empty, a path is generated automatically.
	CheckpointFile *string `type:"string" locationName:"CheckpointFile"`
	// Range selects the bytes to download as [start, end], both inclusive offsets:
	//   - start >= 0 and end >= start: bytes start..end (end is capped at the last byte).
	//   - start >= 0 and end < 0: from start to the end of the object.
	//   - start < 0 and end >= 0: the last `end` bytes of the object.
	//   - start < 0 and end < 0: the whole object.
	//   - start > end with end >= 0: the whole object.
	//   - start >= object size, or a slice whose length is not 2: the whole object.
	// Examples:
	//   range=[0, 99]    the first 100 bytes
	//   range=[100, 199] bytes 101 through 200
	//   range=[100, -1]  byte 101 through the end of the object
	//   range=[-1, 100]  the last 100 bytes
	// When Range is non-nil the client-side CRC64 check is skipped.
	// Downloads the specified range bytes of an object.
	Range []int64 `locationName:"Range" type:"list"`
	// Sets the Content-Type header of the response.
	ResponseContentType *string `location:"querystring" locationName:"response-content-type" type:"string"`
	// Sets the Content-Language header of the response.
	ResponseContentLanguage *string `location:"querystring" locationName:"response-content-language" type:"string"`
	// Sets the Expires header of the response.
	ResponseExpires *time.Time `location:"querystring" locationName:"response-expires" type:"timestamp" timestampFormat:"iso8601"`
	// Sets the Cache-Control header of the response.
	ResponseCacheControl *string `location:"querystring" locationName:"response-cache-control" type:"string"`
	// Sets the Content-Disposition header of the response.
	ResponseContentDisposition *string `location:"querystring" locationName:"response-content-disposition" type:"string"`
	// Sets the Content-Encoding header of the response.
	ResponseContentEncoding *string `location:"querystring" locationName:"response-content-encoding" type:"string"`
	// Return the object only if it has been modified since the specified time,
	// otherwise return a 304 (not modified).
	// Also sent on the initial HeadObject request.
	IfModifiedSince *time.Time `location:"header" locationName:"If-Modified-Since" type:"timestamp" timestampFormat:"rfc822"`
	// Return the object only if it has not been modified since the specified time,
	// otherwise return a 412 (precondition failed).
	// Also sent on the initial HeadObject request.
	IfUnmodifiedSince *time.Time `location:"header" locationName:"If-Unmodified-Since" type:"timestamp" timestampFormat:"rfc822"`
	// Return the object only if its entity tag (ETag) is the same as the one specified,
	// otherwise return a 412 (precondition failed).
	// Also sent on the initial HeadObject request.
	IfMatch *string `location:"header" locationName:"If-Match" type:"string"`
	// Return the object only if its entity tag (ETag) is different from the one
	// specified, otherwise return a 304 (not modified).
	// Also sent on the initial HeadObject request.
	IfNoneMatch *string `location:"header" locationName:"If-None-Match" type:"string"`
	// Specify the encoding type of the client.
	// If you want to compress and transmit the returned content using gzip,
	// you need to add a request header: Accept-Encoding: gzip.
	// KS3 will determine whether to return gzip compressed data based on the
	// Content-Type and Object size (not less than 1 KB) of the object.
	// Values: gzip, br, deflate
	AcceptEncoding *string `location:"header" locationName:"Accept-Encoding" type:"string"`
	// Specifies the algorithm to use when encrypting the object, eg: AES256.
	SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
	// Specifies the customer-provided encryption key for KS3 to use in encrypting data.
	SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
	// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
	SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
	// Progress callback function. It is invoked with the bytes just completed, the
	// running total of completed bytes, and the total number of bytes to download.
	// Parts skipped because of a checkpoint are reported too. It may be called
	// concurrently from the worker goroutines.
	ProgressFn aws.ProgressFunc `location:"function"`
}
// DownloadFileOutput describes a successfully downloaded object.
type DownloadFileOutput struct {
	// Bucket and Key echo the values from the request.
	Bucket *string
	Key    *string
	// ETag of the object, taken from the HeadObject metadata (nil if absent).
	ETag *string
	// CRC64-ECMA checksum from the HeadObject metadata (nil if absent).
	ChecksumCRC64ECMA *string
	// ObjectMeta is the metadata map returned by HeadObject.
	ObjectMeta map[string]*string
}
- func (c *S3) DownloadFile(request *DownloadFileInput) (*DownloadFileOutput, error) {
- return c.DownloadFileWithContext(context.Background(), request)
- }
- func (c *S3) DownloadFileWithContext(ctx context.Context, request *DownloadFileInput) (*DownloadFileOutput, error) {
- return newDownloader(c, ctx, request).downloadFile()
- }
// Downloader holds the state of a single DownloadFile call. It is shared by the
// worker goroutines that fetch parts concurrently.
type Downloader struct {
	client *S3
	// context is passed to every HeadObject/GetObject request.
	context             context.Context
	downloadFileRequest *DownloadFileInput
	// downloadCheckpoint records completed parts; appends to its PartETagList
	// happen under mu.
	downloadCheckpoint *DownloadCheckpoint
	// CompletedSize is the number of bytes reported as done so far. It is
	// updated with atomic.AddInt64, and only when a ProgressFn is set.
	CompletedSize int64
	// downloadFileSize is the number of bytes being downloaded (the whole object
	// or the selected range).
	downloadFileSize int64
	// downloadFileMeta is the metadata map returned by HeadObject.
	downloadFileMeta map[string]*string
	// mu guards error and checkpoint updates.
	mu sync.Mutex
	// error is the first error reported by a worker (see setError).
	error error
}
- func newDownloader(s3 *S3, ctx context.Context, request *DownloadFileInput) *Downloader {
- return &Downloader{
- client: s3,
- context: ctx,
- downloadFileRequest: request,
- }
- }
- func (d *Downloader) downloadFile() (*DownloadFileOutput, error) {
- err := d.validate()
- if err != nil {
- return nil, err
- }
- d.downloadFileMeta, err = d.headObject()
- if err != nil {
- return nil, err
- }
- dcp, err := newDownloadCheckpoint(d)
- if err != nil {
- return nil, err
- }
- d.downloadCheckpoint = dcp
- if aws.ToBoolean(d.downloadFileRequest.EnableCheckpoint) {
- cpFilePath := aws.ToString(d.downloadFileRequest.CheckpointFile)
- if cpFilePath == "" {
- cpFilePath, err = generateDownloadCpFilePath(d.downloadFileRequest)
- if err != nil {
- return nil, err
- }
- }
- dcp.CpFilePath = cpFilePath
- err = dcp.load()
- if err != nil {
- return nil, err
- }
- if !FileExists(dcp.DownloadFilePath + TempFileSuffix) {
- dcp.PartETagList = make([]*CompletedPart, 0)
- dcp.remove()
- }
- }
- err = d.createDownloadDir(dcp.DownloadFilePath + TempFileSuffix)
- if err != nil {
- return nil, err
- }
- objectRange := d.getObjectRange()
- d.downloadFileSize = objectRange[1] - objectRange[0] + 1
- partSize := aws.ToLong(d.downloadFileRequest.PartSize)
- totalPartNum := (d.downloadFileSize-1)/partSize + 1
- tasks := make(chan DownloadPartTask, totalPartNum)
- var i int64
- for i = 0; i < totalPartNum; i++ {
- partNum := i + 1
- start := objectRange[0] + i*partSize
- end := Min(start+partSize-1, objectRange[1])
- actualPartSize := end - start + 1
- if d.getPartETag(partNum) != nil {
- d.publishProgress(actualPartSize)
- } else {
- downloadPartTask := DownloadPartTask{
- partNumber: partNum,
- start: start,
- end: end,
- actualPartSize: actualPartSize,
- }
- tasks <- downloadPartTask
- }
- }
- close(tasks)
- var wg sync.WaitGroup
- for i = 0; i < aws.ToLong(d.downloadFileRequest.TaskNum); i++ {
- wg.Add(1)
- go d.runTask(tasks, &wg)
- }
- wg.Wait()
- if d.error != nil {
- return nil, d.error
- }
- if d.downloadFileRequest.Range == nil && d.client.Config.CrcCheckEnabled {
- clientCrc64 := d.getCrc64Ecma(dcp.PartETagList)
- serverCrc64, _ := strconv.ParseUint(aws.ToString(d.downloadFileMeta[HTTPHeaderAmzChecksumCrc64ecma]), 10, 64)
- d.client.Config.LogDebug("check file crc64, client crc64:%d, server crc64:%d", clientCrc64, serverCrc64)
- if serverCrc64 != 0 && clientCrc64 != serverCrc64 {
- return nil, errors.New(fmt.Sprintf("crc64 check failed, client crc64:%d, server crc64:%d", clientCrc64, serverCrc64))
- }
- }
- err = d.complete()
- if err != nil {
- return nil, err
- }
- return d.getDownloadFileOutput(), nil
- }
- func (d *Downloader) validate() error {
- request := d.downloadFileRequest
- if request == nil {
- return errors.New("download file request is required")
- }
- if aws.ToString(request.Bucket) == "" {
- return errors.New("bucket is required")
- }
- if aws.ToString(request.Key) == "" {
- return errors.New("key is required")
- }
- err := d.normalizeDownloadPath()
- if err != nil {
- return err
- }
- if request.PartSize == nil {
- request.PartSize = aws.Long(DefaultPartSize)
- } else if aws.ToLong(request.PartSize) < MinPartSize {
- request.PartSize = aws.Long(MinPartSize)
- } else if aws.ToLong(request.PartSize) > MaxPartSize {
- request.PartSize = aws.Long(MaxPartSize)
- }
- if aws.ToLong(request.TaskNum) <= 0 {
- request.TaskNum = aws.Long(DefaultTaskNum)
- }
- return nil
- }
- func (d *Downloader) getDownloadFileOutput() *DownloadFileOutput {
- return &DownloadFileOutput{
- Bucket: d.downloadFileRequest.Bucket,
- Key: d.downloadFileRequest.Key,
- ETag: d.downloadFileMeta[HTTPHeaderEtag],
- ChecksumCRC64ECMA: d.downloadFileMeta[HTTPHeaderAmzChecksumCrc64ecma],
- ObjectMeta: d.downloadFileMeta,
- }
- }
- func (d *Downloader) getActualPartSize(fileSize int64, partSize int64, partNum int64) int64 {
- offset := (partNum - 1) * partSize
- actualPartSize := partSize
- if offset+partSize >= fileSize {
- actualPartSize = fileSize - offset
- }
- return actualPartSize
- }
- func (d *Downloader) getPartETag(partNumber int64) *CompletedPart {
- for _, partETag := range d.downloadCheckpoint.PartETagList {
- if *partETag.PartNumber == partNumber {
- return partETag
- }
- }
- return nil
- }
// DownloadPartTask describes one ranged GET that a worker performs.
type DownloadPartTask struct {
	// partNumber is the 1-based index of the part within the download.
	partNumber int64
	// actualPartSize is end - start + 1, the number of bytes in this part.
	actualPartSize int64
	// start and end are inclusive byte offsets in the object, sent as
	// "Range: bytes=start-end".
	start int64
	end   int64
}
- func (d *Downloader) runTask(tasks <-chan DownloadPartTask, wg *sync.WaitGroup) {
- defer wg.Done()
- for task := range tasks {
- if d.error != nil {
- return
- }
- partETag, err := d.downloadPart(task)
- if err != nil {
- d.setError(err)
- return
- }
- d.updatePart(partETag)
- }
- }
- func (d *Downloader) downloadPart(task DownloadPartTask) (CompletedPart, error) {
- request := d.downloadFileRequest
- dcp := d.downloadCheckpoint
- tempFilePath := dcp.DownloadFilePath + TempFileSuffix
- var completedPart CompletedPart
- resp, err := d.client.GetObjectWithContext(d.context, &GetObjectInput{
- Bucket: aws.String(dcp.BucketName),
- Key: aws.String(dcp.ObjectKey),
- Range: aws.String(fmt.Sprintf("bytes=%d-%d", task.start, task.end)),
- ResponseContentType: request.ResponseContentType,
- ResponseContentLanguage: request.ResponseContentLanguage,
- ResponseExpires: request.ResponseExpires,
- ResponseCacheControl: request.ResponseCacheControl,
- ResponseContentDisposition: request.ResponseContentDisposition,
- ResponseContentEncoding: request.ResponseContentEncoding,
- IfModifiedSince: request.IfModifiedSince,
- IfUnmodifiedSince: request.IfUnmodifiedSince,
- IfMatch: request.IfMatch,
- IfNoneMatch: request.IfNoneMatch,
- AcceptEncoding: request.AcceptEncoding,
- SSECustomerAlgorithm: request.SSECustomerAlgorithm,
- SSECustomerKey: request.SSECustomerKey,
- SSECustomerKeyMD5: request.SSECustomerKeyMD5,
- })
- if err != nil {
- return completedPart, err
- }
- defer resp.Body.Close()
- var crc64 hash.Hash64
- crc64 = crc.NewCRC(crc.CrcTable(), 0)
- resp.Body = aws.TeeReader(resp.Body, crc64, task.actualPartSize, nil)
- fd, err := os.OpenFile(tempFilePath, os.O_WRONLY|os.O_CREATE, FilePermMode)
- if err != nil {
- return completedPart, err
- }
- defer fd.Close()
- _, err = fd.Seek((task.partNumber-1)*dcp.PartSize, io.SeekStart)
- if err != nil {
- return completedPart, err
- }
- _, err = io.Copy(fd, resp.Body)
- if err != nil {
- return completedPart, err
- }
- completedPart.PartNumber = aws.Long(task.partNumber)
- completedPart.ChecksumCRC64ECMA = aws.String(strconv.FormatUint(crc64.Sum64(), 10))
- d.publishProgress(task.actualPartSize)
- return completedPart, nil
- }
- func (d *Downloader) updatePart(partETag CompletedPart) {
- d.mu.Lock()
- defer d.mu.Unlock()
- d.downloadCheckpoint.PartETagList = append(d.downloadCheckpoint.PartETagList, &partETag)
- d.downloadCheckpoint.dump()
- }
- func (d *Downloader) setError(err error) {
- d.mu.Lock()
- defer d.mu.Unlock()
- if d.error == nil {
- d.error = err
- }
- }
- func (d *Downloader) publishProgress(actualPartSize int64) {
- if d.downloadFileRequest.ProgressFn != nil {
- atomic.AddInt64(&d.CompletedSize, actualPartSize)
- d.downloadFileRequest.ProgressFn(actualPartSize, d.CompletedSize, d.downloadFileSize)
- }
- }
- func (d *Downloader) getCrc64Ecma(parts []*CompletedPart) uint64 {
- if parts == nil || len(parts) == 0 {
- return 0
- }
- sort.Sort(CompletedParts(d.downloadCheckpoint.PartETagList))
- crcTemp, _ := strconv.ParseUint(*parts[0].ChecksumCRC64ECMA, 10, 64)
- for i := 1; i < len(parts); i++ {
- crc2, _ := strconv.ParseUint(*parts[i].ChecksumCRC64ECMA, 10, 64)
- partSize := d.getActualPartSize(d.downloadFileSize, aws.ToLong(d.downloadFileRequest.PartSize), *parts[i].PartNumber)
- crcTemp = crc.CRC64Combine(crcTemp, crc2, (uint64)(partSize))
- }
- return crcTemp
- }
- func (d *Downloader) complete() error {
- fileName := aws.ToString(d.downloadFileRequest.DownloadFile)
- tempFileName := fileName + TempFileSuffix
- err := os.Rename(tempFileName, fileName)
- if err != nil {
- return err
- }
- d.downloadCheckpoint.remove()
- return nil
- }
- func (d *Downloader) headObject() (map[string]*string, error) {
- request := d.downloadFileRequest
- resp, err := d.client.HeadObjectWithContext(d.context, &HeadObjectInput{
- Bucket: request.Bucket,
- Key: request.Key,
- IfModifiedSince: request.IfModifiedSince,
- IfUnmodifiedSince: request.IfUnmodifiedSince,
- IfMatch: request.IfMatch,
- IfNoneMatch: request.IfNoneMatch,
- SSECustomerAlgorithm: request.SSECustomerAlgorithm,
- SSECustomerKey: request.SSECustomerKey,
- SSECustomerKeyMD5: request.SSECustomerKeyMD5,
- })
- if err != nil {
- return nil, err
- }
- return resp.Metadata, err
- }
- func (d *Downloader) createDownloadDir(filePath string) error {
- dir := filepath.Dir(filePath)
- if !DirExists(dir) {
- err := os.MkdirAll(dir, DirPermMode)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (d *Downloader) getObjectRange() []int64 {
- objectRange := d.downloadFileRequest.Range
- objectSize := d.downloadCheckpoint.ObjectSize
- if objectRange == nil {
- return []int64{0, objectSize - 1}
- }
- if !d.isValidRange(objectRange, objectSize) {
- d.client.Config.LogWarn("Invalid range value: %v, ignore it and request for entire object", objectRange)
- return []int64{0, objectSize - 1}
- }
- objectStart := objectRange[0]
- objectEnd := objectRange[1]
- if objectStart < 0 {
- return []int64{objectSize - objectEnd, objectSize - 1}
- }
- if objectEnd < 0 {
- return []int64{objectStart, objectSize - 1}
- }
- return []int64{objectStart, Min(objectEnd, objectSize-1)}
- }
- func (d *Downloader) isValidRange(objectRange []int64, objectSize int64) bool {
- if len(objectRange) != 2 {
- return false
- }
- objectStart := objectRange[0]
- objectEnd := objectRange[1]
- if objectStart < 0 && objectEnd < 0 || objectEnd >= 0 && objectStart > objectEnd {
- return false
- }
- return objectStart < objectSize
- }
- func (d *Downloader) normalizeDownloadPath() error {
- downloadPath := aws.ToString(d.downloadFileRequest.DownloadFile)
- if downloadPath == "" {
- downloadPath = aws.ToString(d.downloadFileRequest.Key)
- }
- // 规范化路径
- normalizedPath := filepath.Clean(downloadPath)
- // 获取绝对路径
- absPath, err := filepath.Abs(normalizedPath)
- if err != nil {
- return err
- }
- d.downloadFileRequest.DownloadFile = aws.String(absPath)
- return nil
- }
|