// Vendored file: peridot/vendor/github.com/aws/aws-sdk-go/service/s3/s3manager/download.go
// (Go source, 609 lines, 18 KiB; captured from a repository history view,
// revision dated 2022-07-07 20:11:50 +00:00).

package s3manager
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
const (
	// DefaultDownloadPartSize is the default number of bytes (5 MiB) requested
	// per ranged GET when using Download().
	DefaultDownloadPartSize = 5 * 1024 * 1024

	// DefaultDownloadConcurrency is the default number of worker goroutines
	// spun up when using Download().
	DefaultDownloadConcurrency = 5
)
// errReadingBody wraps an error that occurred while copying a GetObject
// response body into the destination writer. downloadChunk looks for this
// type to tell body-copy failures, which it retries, apart from request
// failures, which it returns immediately.
type errReadingBody struct {
	err error
}

// Error implements the error interface, prefixing the wrapped error's text.
func (e *errReadingBody) Error() string {
	return "failed to read part body: " + fmt.Sprint(e.err)
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *errReadingBody) Unwrap() error {
	return e.err
}
// Downloader downloads objects from S3 using ranged GET requests, optionally
// fetching the parts concurrently. It is safe to call Download() on this
// structure for multiple objects and across concurrent goroutines.
// Mutating the Downloader's properties is not safe to be done concurrently.
type Downloader struct {
	// PartSize is the number of bytes requested from S3 for each part (each
	// ranged GET). A zero value is replaced with DefaultDownloadPartSize
	// (5 MiB) when a download starts.
	//
	// NOTE(review): earlier documentation described 5MB as the minimum
	// allowed part size, but nothing in this file enforces a minimum —
	// confirm before relying on it.
	//
	// PartSize is ignored if the Range input parameter is provided.
	PartSize int64

	// Concurrency is the number of worker goroutines that download parts
	// once the object's total size is known from the first response. A zero
	// value is replaced with DefaultDownloadConcurrency.
	//
	// Concurrency of 1 will download the parts sequentially. When the first
	// response does not report a total size, parts are always fetched one at
	// a time regardless of this setting.
	//
	// Concurrency is ignored if the Range input parameter is provided.
	Concurrency int

	// S3 is the client used for every GetObject request. If it also has a
	// MaxRetries() int method, that value sets how many times a part is
	// retried after its response body fails to copy.
	S3 s3iface.S3API

	// RequestOptions are passed to every GetObject request made by the
	// downloader. Each download additionally appends
	// request.WithAppendUserAgent("S3Manager") to its own copy of the
	// Downloader.
	RequestOptions []request.Option

	// BufferProvider defines the buffer strategy used when downloading a part.
	//
	// When non-nil, each part's destination writer (which writes into the
	// download's io.WriterAt at that part's offset) is passed to the
	// provider's GetReadFrom, and the response body is copied into the writer
	// the provider returns; the returned cleanup func runs when that attempt
	// ends. When nil, the body is copied straight into the part's writer.
	// NewDownloader and NewDownloaderWithClient set this to
	// defaultDownloadBufferProvider().
	BufferProvider WriterReadFromProvider
}
// WithDownloaderRequestOptions returns a Downloader option that appends opts
// to the Downloader's RequestOptions, the options passed to every API request
// the downloader makes. The returned func can be given to NewDownloader,
// NewDownloaderWithClient, or an individual Download call.
func WithDownloaderRequestOptions(opts ...request.Option) func(*Downloader) {
	appendOpts := func(dl *Downloader) {
		dl.RequestOptions = append(dl.RequestOptions, opts...)
	}
	return appendOpts
}
// NewDownloader creates a new Downloader instance to downloads objects from
// S3 in concurrent chunks. Pass in additional functional options to customize
// the downloader behavior. Requires a client.ConfigProvider in order to create
// a S3 service client. The session.Session satisfies the client.ConfigProvider
// interface.
//
// Example:
//
2022-11-04 02:21:49 +00:00
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
2022-07-07 20:11:50 +00:00
//
2022-11-04 02:21:49 +00:00
// // Create a downloader with the session and default options
// downloader := s3manager.NewDownloader(sess)
//
// // Create a downloader with the session and custom options
// downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
2022-07-07 20:11:50 +00:00
func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
return newDownloader(s3.New(c), options...)
}
// newDownloader builds a Downloader around client using the package
// defaults for part size, concurrency and buffer provider, then applies each
// option in order so callers can override those defaults.
func newDownloader(client s3iface.S3API, options ...func(*Downloader)) *Downloader {
	dl := new(Downloader)
	dl.S3 = client
	dl.PartSize = DefaultDownloadPartSize
	dl.Concurrency = DefaultDownloadConcurrency
	dl.BufferProvider = defaultDownloadBufferProvider()

	for i := range options {
		options[i](dl)
	}
	return dl
}
// NewDownloaderWithClient creates a new Downloader instance to downloads
// objects from S3 in concurrent chunks. Pass in additional functional
// options to customize the downloader behavior. Requires a S3 service client
// to make S3 API calls.
//
// Example:
//
2022-11-04 02:21:49 +00:00
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
//
// // The S3 client the S3 Downloader will use
// s3Svc := s3.New(sess)
2022-07-07 20:11:50 +00:00
//
2022-11-04 02:21:49 +00:00
// // Create a downloader with the s3 client and default options
// downloader := s3manager.NewDownloaderWithClient(s3Svc)
2022-07-07 20:11:50 +00:00
//
2022-11-04 02:21:49 +00:00
// // Create a downloader with the s3 client and custom options
// downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Downloader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
2022-07-07 20:11:50 +00:00
func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
return newDownloader(svc, options...)
}
// maxRetrier is satisfied by S3 clients that report their maximum retry
// count. DownloadWithContext uses it to bound how many times a part whose
// response body fails to copy is retried.
type maxRetrier interface {
	// MaxRetries returns the client's maximum number of retries.
	MaxRetries() int
}
// Download fetches the S3 object described by input and writes its payload
// into w, using concurrent ranged GET requests. It returns the number of
// bytes written. It is DownloadWithContext called with
// aws.BackgroundContext().
//
// Each option receives a pointer to a per-call copy of the Downloader, so
// changes an option makes do not affect the Downloader Download was called
// on.
//
// It is safe to call this method concurrently across goroutines.
//
// w may be an *os.File for concurrent multipart downloads, or an in-memory
// []byte wrapper such as aws.WriteAtBuffer.
//
// A Downloader.Concurrency of 1 downloads the parts from S3 one after another.
//
// If the GetObjectInput's Range value is set, the downloader performs a single
// GetObject request for exactly that range, and the part size and concurrency
// settings are ignored.
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
	ctx := aws.BackgroundContext()
	return d.DownloadWithContext(ctx, w, input, options...)
}
// DownloadWithContext downloads an object in S3 and writes the payload into w
// using concurrent GET requests. The n int64 returned is the size of the object downloaded
// in bytes.
//
// DownloadWithContext is the same as Download with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, etc. The
// DownloadWithContext may create sub-contexts for individual underlying
// requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is
// called from. Modifying the options will not impact the original Downloader
// instance. Use the WithDownloaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this downloader.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// It is safe to call this method concurrently across goroutines.
//
// If the GetObjectInput's Range value is provided that will cause the downloader
// to perform a single GetObjectInput request for that object's range. This will
// caused the part size, and concurrency configurations to be ignored.
func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
2022-11-04 02:21:49 +00:00
if err := validateSupportedARNType(aws.StringValue(input.Bucket)); err != nil {
return 0, err
}
2022-07-07 20:11:50 +00:00
impl := downloader{w: w, in: input, cfg: d, ctx: ctx}
for _, option := range options {
option(&impl.cfg)
}
impl.cfg.RequestOptions = append(impl.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))
if s, ok := d.S3.(maxRetrier); ok {
impl.partBodyMaxRetries = s.MaxRetries()
}
impl.totalBytes = -1
if impl.cfg.Concurrency == 0 {
impl.cfg.Concurrency = DefaultDownloadConcurrency
}
if impl.cfg.PartSize == 0 {
impl.cfg.PartSize = DefaultDownloadPartSize
}
return impl.download()
}
// DownloadWithIterator will download a batched amount of objects in S3 and writes them
// to the io.WriterAt specificed in the iterator.
//
// Example:
2022-11-04 02:21:49 +00:00
//
2022-07-07 20:11:50 +00:00
// svc := s3manager.NewDownloader(session)
//
// fooFile, err := os.Open("/tmp/foo.file")
// if err != nil {
// return err
// }
//
// barFile, err := os.Open("/tmp/bar.file")
// if err != nil {
// return err
// }
//
// objects := []s3manager.BatchDownloadObject {
// {
// Object: &s3.GetObjectInput {
// Bucket: aws.String("bucket"),
// Key: aws.String("foo"),
// },
// Writer: fooFile,
// },
// {
// Object: &s3.GetObjectInput {
// Bucket: aws.String("bucket"),
// Key: aws.String("bar"),
// },
// Writer: barFile,
// },
// }
//
// iter := &s3manager.DownloadObjectsIterator{Objects: objects}
// if err := svc.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil {
// return err
// }
func (d Downloader) DownloadWithIterator(ctx aws.Context, iter BatchDownloadIterator, opts ...func(*Downloader)) error {
var errs []Error
for iter.Next() {
object := iter.DownloadObject()
if _, err := d.DownloadWithContext(ctx, object.Writer, object.Object, opts...); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
if object.After == nil {
continue
}
if err := object.After(); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
}
if len(errs) > 0 {
return NewBatchError("BatchedDownloadIncomplete", "some objects have failed to download.", errs)
}
return nil
}
// downloader is the implementation structure used internally by Downloader.
// It holds the state of one Download call: DownloadWithContext builds it and
// then runs download.
type downloader struct {
	// ctx is passed to every GetObjectWithContext call.
	ctx aws.Context
	// cfg is the per-call copy of the Downloader, with call options applied
	// and PartSize/Concurrency defaults filled in.
	cfg Downloader

	// in is the caller's input; downloadChunk copies it for each part and
	// sets that part's Range.
	in *s3.GetObjectInput
	// w receives every part, each written at its own offset.
	w io.WriterAt

	// wg tracks the downloadPart worker goroutines.
	wg sync.WaitGroup
	// m serializes access to totalBytes, written and err while workers run.
	m sync.Mutex

	// pos is the offset of the next part to request. It is only read and
	// written by the goroutine running download; workers never touch it.
	pos int64
	// totalBytes is the object size reported by a response, or -1 while it
	// is unknown.
	totalBytes int64
	// written counts the bytes written to w so far.
	written int64
	// err is the error recorded by setErr (or by setTotalBytes); once set,
	// no further parts are requested.
	err error

	// partBodyMaxRetries is how many extra attempts a part gets after its
	// response body fails to copy; taken from the client's MaxRetries().
	partBodyMaxRetries int
}
// download performs the object download across ranged GETs and returns the
// number of bytes written together with the recorded error.
//
// A caller-supplied Range is fetched with a single request. Otherwise the
// first part is fetched on the calling goroutine. If that response reports
// the object's total size, the remaining parts are handed to
// d.cfg.Concurrency downloadPart workers. If the size is unknown, parts are
// fetched one at a time until a request fails, and a 416 (Requested Range Not
// Satisfiable) failure is treated as the normal end of the object.
func (d *downloader) download() (n int64, err error) {
	if rng := aws.StringValue(d.in.Range); rng != "" {
		// A ranged get gives up multipart downloading in exchange for
		// fetching exactly the requested bytes.
		d.downloadRange(rng)
		return d.written, d.err
	}

	// The first part also resolves the object's total size, if reported.
	d.getChunk()

	total := d.getTotalBytes()
	if total < 0 {
		// Size unknown: keep requesting consecutive parts on this goroutine
		// until one fails. Running past the end of the content yields a
		// 416, which signals completion rather than failure.
		for d.err == nil {
			d.getChunk()
		}
		if reqErr, isReqErr := d.err.(awserr.RequestFailure); isReqErr &&
			reqErr.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
			d.err = nil
		}
		return d.written, d.err
	}

	// Size known: fan the remaining parts out to a fixed pool of workers.
	parts := make(chan dlchunk, d.cfg.Concurrency)
	for worker := 0; worker < d.cfg.Concurrency; worker++ {
		d.wg.Add(1)
		go d.downloadPart(parts)
	}

	// Queue parts until the end of the object or until an error is
	// recorded. Workers keep draining the channel after an error, so these
	// sends cannot block forever.
	for d.getErr() == nil && d.pos < total {
		parts <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
		d.pos += d.cfg.PartSize
	}

	close(parts)
	d.wg.Wait()

	return d.written, d.err
}
// downloadPart is a worker goroutine. It receives chunks from ch until ch is
// closed and issues a ranged GetObject for each one, recording any failure
// with setErr. Once an error has been recorded it keeps receiving but skips
// the work, so the producer in download never blocks on a full channel.
func (d *downloader) downloadPart(ch chan dlchunk) {
	defer d.wg.Done()

	for chunk := range ch {
		if d.getErr() != nil {
			continue // drain only
		}
		if err := d.downloadChunk(chunk); err != nil {
			d.setErr(err)
		}
	}
}
// getChunk downloads the next PartSize bytes starting at d.pos on the calling
// goroutine and advances d.pos past them, recording any download error. It
// does nothing once an error has been recorded.
//
// Not thread safe: d.pos is read and written without locking, so this must
// only be used when grabbing data on a single goroutine.
func (d *downloader) getChunk() {
	if d.getErr() == nil {
		next := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
		d.pos += next.size
		if err := d.downloadChunk(next); err != nil {
			d.setErr(err)
		}
	}
}
// downloadRange downloads the object bytes selected by rng, a caller-supplied
// Range header value, using a single chunk that starts at d.pos. Afterwards
// d.pos is set to the number of bytes written. It does nothing if an error
// has already been recorded.
func (d *downloader) downloadRange(rng string) {
	if d.getErr() == nil {
		// withRange makes the chunk ignore its size and send rng verbatim,
		// short-circuiting the multipart download.
		chunk := dlchunk{w: d.w, start: d.pos, withRange: rng}
		if err := d.downloadChunk(chunk); err != nil {
			d.setErr(err)
		}
		d.pos = d.written
	}
}
// downloadChunk downloads chunk from S3, writes it through the chunk's
// io.WriterAt, and adds the bytes written to d.written.
//
// The caller's GetObjectInput is copied and given the chunk's byte range.
// If copying the response body fails (tryDownloadChunk returns an
// *errReadingBody), the chunk is rewound and the request repeated, with up to
// d.partBodyMaxRetries extra attempts; if every attempt fails, the last
// underlying body error is returned. Any other error, such as the request
// itself failing, is returned immediately.
func (d *downloader) downloadChunk(chunk dlchunk) error {
	in := &s3.GetObjectInput{}
	awsutil.Copy(in, d.in)

	// Get the next byte range of data
	in.Range = aws.String(chunk.ByteRange())

	// A negative retry count would otherwise skip the loop entirely and
	// report success without downloading anything; always try at least once.
	maxRetries := d.partBodyMaxRetries
	if maxRetries < 0 {
		maxRetries = 0
	}

	var n int64
	var err error
	for retry := 0; retry <= maxRetries; retry++ {
		n, err = d.tryDownloadChunk(in, &chunk)
		if err == nil {
			break
		}

		// Only failures while copying the response body are retried; unwrap
		// them so the caller sees the underlying error.
		bodyErr, ok := err.(*errReadingBody)
		if !ok {
			return err
		}
		err = bodyErr.Unwrap()

		if retry == maxRetries {
			// No attempts left; don't log a retry that will not happen.
			break
		}

		// Rewind so the next attempt rewrites the chunk from its start.
		chunk.cur = 0
		logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
			fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
				aws.StringValue(in.Key), err, retry+1))
	}

	d.incrWritten(n)
	return err
}
func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64, error) {
cleanup := func() {}
if d.cfg.BufferProvider != nil {
w, cleanup = d.cfg.BufferProvider.GetReadFrom(w)
}
defer cleanup()
resp, err := d.cfg.S3.GetObjectWithContext(d.ctx, in, d.cfg.RequestOptions...)
if err != nil {
return 0, err
}
d.setTotalBytes(resp) // Set total if not yet set.
2022-11-04 02:21:49 +00:00
var src io.Reader = resp.Body
if d.cfg.BufferProvider != nil {
src = &suppressWriterAt{suppressed: src}
}
n, err := io.Copy(w, src)
2022-07-07 20:11:50 +00:00
resp.Body.Close()
if err != nil {
return n, &errReadingBody{err: err}
}
return n, nil
}
// logMessage writes msg to svc's configured logger when svc is an *s3.S3
// whose Config has a non-nil Logger and whose LogLevel matches level.
// Otherwise it does nothing.
func logMessage(svc s3iface.S3API, level aws.LogLevelType, msg string) {
	client, isS3 := svc.(*s3.S3)
	if !isS3 || client.Config.Logger == nil || !client.Config.LogLevel.Matches(level) {
		return
	}
	client.Config.Logger.Log(msg)
}
// getTotalBytes returns the object size recorded by setTotalBytes, or a
// negative value while it is still unknown. It holds d.m while reading, so it
// is safe to call from worker goroutines.
func (d *downloader) getTotalBytes() int64 {
	d.m.Lock()
	total := d.totalBytes
	d.m.Unlock()
	return total
}
// setTotalBytes records the object's total size from resp, but only if it is
// not known yet. It holds d.m, so it is safe to call from worker goroutines.
//
// With a Content-Range header, the total is the text after its last "/"
// (e.g. "bytes 0-99/1234"). A "*" total leaves the size unknown (-1), so the
// download falls back to fetching parts sequentially until a 416. A total
// that fails to parse is stored in d.err. Without Content-Range, which is
// the case when the full object fits in one part and was not chunked,
// Content-Length is used when present.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
	d.m.Lock()
	defer d.m.Unlock()

	switch {
	case d.totalBytes >= 0:
		// Already resolved by an earlier response.
		return
	case resp.ContentRange == nil:
		if resp.ContentLength != nil {
			d.totalBytes = *resp.ContentLength
		}
		return
	}

	contentRange := *resp.ContentRange
	totalStr := contentRange[strings.LastIndex(contentRange, "/")+1:]
	if totalStr == "*" {
		d.totalBytes = -1
		return
	}

	total, err := strconv.ParseInt(totalStr, 10, 64)
	if err != nil {
		d.err = err
		return
	}
	d.totalBytes = total
}
// incrWritten adds n to the running count of bytes written, holding d.m.
func (d *downloader) incrWritten(n int64) {
	d.m.Lock()
	d.written += n
	d.m.Unlock()
}

// getErr returns the currently recorded error, holding d.m while reading.
func (d *downloader) getErr() error {
	d.m.Lock()
	current := d.err
	d.m.Unlock()
	return current
}

// setErr records e as the download's error, replacing any previously
// recorded one, while holding d.m.
func (d *downloader) setErr(e error) {
	d.m.Lock()
	d.err = e
	d.m.Unlock()
}
// dlchunk represents a single chunk of data to write by the worker routine.
// It adapts an io.WriterAt into an io.Writer that writes sequentially from
// start, in the style of io.SectionReader: effectively an io.SectionWriter
// (which does not exist in the standard library).
type dlchunk struct {
	w     io.WriterAt
	start int64
	size  int64
	cur   int64

	// specifies the byte range the chunk should be downloaded with.
	withRange string
}

// Write implements io.Writer by writing p into the chunk's io.WriterAt at
// offset start+cur and advancing cur by the bytes written.
//
// For a sized chunk (withRange empty), at most size bytes are written in
// total. Once the chunk is full, or when p holds more bytes than remain,
// Write stores only what fits and returns io.EOF. This keeps surplus bytes
// from spilling into the region of a neighbouring chunk, which another
// worker may be writing concurrently.
//
// If a range is specified on the dlchunk, the size is ignored when writing,
// as the total size may not be known ahead of time.
func (c *dlchunk) Write(p []byte) (n int, err error) {
	if len(c.withRange) == 0 {
		remaining := c.size - c.cur
		if remaining <= 0 {
			return 0, io.EOF
		}
		if int64(len(p)) > remaining {
			// io.Writer requires a non-nil error when n < len(p).
			n, err = c.w.WriteAt(p[:remaining], c.start+c.cur)
			c.cur += int64(n)
			if err == nil {
				err = io.EOF
			}
			return n, err
		}
	}

	n, err = c.w.WriteAt(p, c.start+c.cur)
	c.cur += int64(n)
	return n, err
}

// ByteRange returns a HTTP Byte-Range header value that should be used by the
// client to request the chunk's range: withRange verbatim when set, otherwise
// the inclusive range covering size bytes from start.
func (c *dlchunk) ByteRange() string {
	if len(c.withRange) != 0 {
		return c.withRange
	}

	return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
}