From 8cee779bd8c800788b04daba454dceb52adead54 Mon Sep 17 00:00:00 2001 From: Bertrand Paquet Date: Mon, 16 Dec 2024 14:58:47 +0100 Subject: [PATCH] Implement a parallel download on S3 Signed-off-by: Bertrand Paquet --- cache/remotecache/s3/readerat.go | 75 --------- cache/remotecache/s3/s3.go | 174 +++++++++++--------- cache/remotecache/s3/s3_downloader.go | 219 ++++++++++++++++++++++++++ 3 files changed, 317 insertions(+), 151 deletions(-) delete mode 100644 cache/remotecache/s3/readerat.go create mode 100644 cache/remotecache/s3/s3_downloader.go diff --git a/cache/remotecache/s3/readerat.go b/cache/remotecache/s3/readerat.go deleted file mode 100644 index 666606817ec4..000000000000 --- a/cache/remotecache/s3/readerat.go +++ /dev/null @@ -1,75 +0,0 @@ -package s3 - -import ( - "io" -) - -type ReaderAtCloser interface { - io.ReaderAt - io.Closer -} - -type readerAtCloser struct { - offset int64 - rc io.ReadCloser - ra io.ReaderAt - open func(offset int64) (io.ReadCloser, error) - closed bool -} - -func toReaderAtCloser(open func(offset int64) (io.ReadCloser, error)) ReaderAtCloser { - return &readerAtCloser{ - open: open, - } -} - -func (hrs *readerAtCloser) ReadAt(p []byte, off int64) (n int, err error) { - if hrs.closed { - return 0, io.EOF - } - - if hrs.ra != nil { - return hrs.ra.ReadAt(p, off) - } - - if hrs.rc == nil || off != hrs.offset { - if hrs.rc != nil { - hrs.rc.Close() - hrs.rc = nil - } - rc, err := hrs.open(off) - if err != nil { - return 0, err - } - hrs.rc = rc - } - if ra, ok := hrs.rc.(io.ReaderAt); ok { - hrs.ra = ra - n, err = ra.ReadAt(p, off) - } else { - for { - var nn int - nn, err = hrs.rc.Read(p) - n += nn - p = p[nn:] - if nn == len(p) || err != nil { - break - } - } - } - - hrs.offset += int64(n) - return -} - -func (hrs *readerAtCloser) Close() error { - if hrs.closed { - return nil - } - hrs.closed = true - if hrs.rc != nil { - return hrs.rc.Close() - } - - return nil -} diff --git a/cache/remotecache/s3/s3.go b/cache/remotecache/s3/s3.go index ffa2ae6c3eca..1b5f8ed5fe96 100644 --- a/cache/remotecache/s3/s3.go +++ b/cache/remotecache/s3/s3.go @@ -34,36 +34,40 @@ import ( ) const ( - attrBucket = "bucket" - attrRegion = "region" - attrPrefix = "prefix" - attrManifestsPrefix = "manifests_prefix" - attrBlobsPrefix = "blobs_prefix" - attrName = "name" - attrTouchRefresh = "touch_refresh" - attrEndpointURL = "endpoint_url" - attrAccessKeyID = "access_key_id" - attrSecretAccessKey = "secret_access_key" - attrSessionToken = "session_token" - attrUsePathStyle = "use_path_style" - attrUploadParallelism = "upload_parallelism" - maxCopyObjectSize = 5 * 1024 * 1024 * 1024 + attrBucket = "bucket" + attrRegion = "region" + attrPrefix = "prefix" + attrManifestsPrefix = "manifests_prefix" + attrBlobsPrefix = "blobs_prefix" + attrName = "name" + attrTouchRefresh = "touch_refresh" + attrEndpointURL = "endpoint_url" + attrAccessKeyID = "access_key_id" + attrSecretAccessKey = "secret_access_key" + attrSessionToken = "session_token" + attrUsePathStyle = "use_path_style" + attrUploadParallelism = "upload_parallelism" + attrDownloadParallelism = "download_parallelism" + attrDownloadPartSize = "download_part_size" + maxCopyObjectSize = 5 * 1024 * 1024 * 1024 ) type Config struct { - Bucket string - Region string - Prefix string - ManifestsPrefix string - BlobsPrefix string - Names []string - TouchRefresh time.Duration - EndpointURL string - AccessKeyID string - SecretAccessKey string - SessionToken string - UsePathStyle bool - UploadParallelism int + Bucket string + Region string + Prefix string + ManifestsPrefix string + BlobsPrefix string + Names []string + TouchRefresh time.Duration + EndpointURL string + AccessKeyID string + SecretAccessKey string + SessionToken string + UsePathStyle bool + UploadParallelism int + DownloadParallelism int + DownloadPartSize int } func getConfig(attrs map[string]string) (Config, error) { @@ -141,20 +145,48 @@ func getConfig(attrs map[string]string) (Config, error) { uploadParallelism = uploadParallelismInt } + downloadParallism := 4 + downloadParallismStr, ok := attrs[attrDownloadParallelism] + if ok { + downloadParallismInt, err := strconv.Atoi(downloadParallismStr) + if err != nil { + return Config{}, errors.Errorf("download_parallelism must be a positive integer") + } + if downloadParallismInt <= 0 { + return Config{}, errors.Errorf("download_parallelism must be a positive integer") + } + downloadParallism = downloadParallismInt + } + + downloadPartSize := 5 * 1024 * 1024 + downloadPartSizeStr, ok := attrs[attrDownloadPartSize] + if ok { + downloadPartSizeInt, err := strconv.Atoi(downloadPartSizeStr) + if err != nil { + return Config{}, errors.Errorf("download_part_size must be a positive integer") + } + if downloadPartSizeInt <= 0 { + return Config{}, errors.Errorf("download_part_size must be a positive integer") + } + downloadParallism = downloadPartSizeInt + } + return Config{ - Bucket: bucket, - Region: region, - Prefix: prefix, - ManifestsPrefix: manifestsPrefix, - BlobsPrefix: blobsPrefix, - Names: names, - TouchRefresh: touchRefresh, - EndpointURL: endpointURL, - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SessionToken: sessionToken, - UsePathStyle: usePathStyle, - UploadParallelism: uploadParallelism, + Bucket: bucket, + Region: region, + Prefix: prefix, + ManifestsPrefix: manifestsPrefix, + BlobsPrefix: blobsPrefix, + Names: names, + TouchRefresh: touchRefresh, + EndpointURL: endpointURL, + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + UsePathStyle: usePathStyle, + UploadParallelism: uploadParallelism, + DownloadParallelism: downloadParallism, + DownloadPartSize: downloadPartSize, }, nil } @@ -385,22 +417,15 @@ func (i *importer) Resolve(ctx context.Context, _ ocispecs.Descriptor, id string return solver.NewCacheManager(ctx, id, keysStorage, resultStorage), nil } -type readerAt struct { - ReaderAtCloser - size int64 -} - -func (r *readerAt) Size() int64 { - return r.size -} - type s3Client struct { *s3.Client *manager.Uploader - bucket string - prefix string - blobsPrefix string - manifestsPrefix string + bucket string + prefix string + blobsPrefix string + manifestsPrefix string + downloadParallelism int + downloadPartSize int } func newS3Client(ctx context.Context, config Config) (*s3Client, error) { @@ -419,12 +444,14 @@ func newS3Client(ctx context.Context, config Config) (*s3Client, error) { }) return &s3Client{ - Client: client, - Uploader: manager.NewUploader(client), - bucket: config.Bucket, - prefix: config.Prefix, - blobsPrefix: config.BlobsPrefix, - manifestsPrefix: config.ManifestsPrefix, + Client: client, + Uploader: manager.NewUploader(client), + bucket: config.Bucket, + prefix: config.Prefix, + blobsPrefix: config.BlobsPrefix, + manifestsPrefix: config.ManifestsPrefix, + downloadParallelism: config.DownloadParallelism, + downloadPartSize: config.DownloadPartSize, }, nil } @@ -454,19 +481,6 @@ func (s3Client *s3Client) getManifest(ctx context.Context, key string, config *v return true, nil } -func (s3Client *s3Client) getReader(ctx context.Context, key string) (io.ReadCloser, error) { - input := &s3.GetObjectInput{ - Bucket: &s3Client.bucket, - Key: &key, - } - - output, err := s3Client.GetObject(ctx, input) - if err != nil { - return nil, err - } - return output.Body, nil -} - func (s3Client *s3Client) saveMutableAt(ctx context.Context, key string, body io.Reader) error { input := &s3.PutObjectInput{ Bucket: &s3Client.bucket, @@ -587,10 +601,18 @@ func (s3Client *s3Client) touch(ctx context.Context, key string, size *int64) (e } func (s3Client *s3Client) ReaderAt(ctx context.Context, desc ocispecs.Descriptor) (content.ReaderAt, error) { - readerAtCloser := toReaderAtCloser(func(offset int64) (io.ReadCloser, error) { - return s3Client.getReader(ctx, s3Client.blobKey(desc.Digest)) - }) - return &readerAt{ReaderAtCloser: readerAtCloser, size: desc.Size}, nil + key := s3Client.blobKey(desc.Digest) + input := &S3DownloaderInput{ + Bucket: &s3Client.bucket, + Key: &key, + S3Client: s3Client.Client, + Size: desc.Size, + Parallelism: s3Client.downloadParallelism, + PartSize: s3Client.downloadPartSize, + } + downloader := newDownloader(input) + + return downloader.Download(ctx) } func (s3Client *s3Client) manifestKey(name string) string { diff --git a/cache/remotecache/s3/s3_downloader.go b/cache/remotecache/s3/s3_downloader.go new file mode 100644 index 000000000000..9c0400aa766a --- /dev/null +++ b/cache/remotecache/s3/s3_downloader.go @@ -0,0 +1,219 @@ +package s3 + +// To be replaced by direct SDK call when available +// https://github.com/aws/aws-sdk-go-v2/issues/2247 + +import ( + "context" + "io" + "strconv" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/containerd/containerd/content" + "golang.org/x/sync/errgroup" +) + +type S3Downloader struct { + input *S3DownloaderInput +} + +type S3DownloaderInput struct { + Bucket *string + Key *string + S3Client *s3.Client + Size int64 + Parallelism int + PartSize int +} + +type chunkStatus struct { + writeOffset int64 + readOffset int64 + done bool + size int64 + start int64 + buffer []byte + io.Writer +} + +func (chunk *chunkStatus) Write(p []byte) (n int, err error) { + copied := copy(chunk.buffer[chunk.writeOffset:], p) + chunk.writeOffset += int64(copied) + return copied, nil +} + +type s3Reader struct { + content.ReaderAt + inChan chan int + outChan chan error + start int64 + currentOffset int64 + currentChunk int + input *S3DownloaderInput + ctx context.Context + eg *errgroup.Group + groupCtx context.Context + totalChunk int + chunks []chunkStatus +} + +func newDownloader(input *S3DownloaderInput) *S3Downloader { + if input.Parallelism == 0 { + input.Parallelism = 8 + } + if input.PartSize == 0 { + input.PartSize = 1024 * 1024 + } + return &S3Downloader{ + input: input, + } +} + +func (s *S3Downloader) Download(ctx context.Context) (content.ReaderAt, error) { + reader := &s3Reader{ + input: s.input, + currentOffset: -1, + ctx: ctx, + } + return reader, nil +} + +func (r *s3Reader) Size() int64 { + return r.input.Size +} + +func (r *s3Reader) downloadChunk(chunk int) error { + r.chunks[chunk].start = r.start + int64(chunk)*int64(r.input.PartSize) + r.chunks[chunk].size = int64(r.input.PartSize) + if r.chunks[chunk].start+r.chunks[chunk].size > r.input.Size { + r.chunks[chunk].size = r.input.Size - r.chunks[chunk].start + } + + r.chunks[chunk].buffer = make([]byte, r.chunks[chunk].size) + bytesRange := r.buildRange(r.chunks[chunk].start, r.chunks[chunk].size) + input := &s3.GetObjectInput{ + Bucket: r.input.Bucket, + Key: r.input.Key, + Range: &bytesRange, + } + resp, err := r.input.S3Client.GetObject(r.groupCtx, input) + if err != nil { + return err + } + n, err := io.Copy(&r.chunks[chunk], resp.Body) + if err != nil { + return err + } + if err := resp.Body.Close(); err != nil { + return err + } + r.chunks[chunk].size = n + r.chunks[chunk].done = true + return nil +} + +func (r *s3Reader) buildRange(start int64, size int64) string { + startRange := strconv.FormatInt(start, 10) + stopRange := strconv.FormatInt(start+size-1, 10) + return "bytes=" + startRange + "-" + stopRange +} + +func (r *s3Reader) reset(offset int64) error { + if err := r.Close(); err != nil { + return err + } + r.eg, r.groupCtx = errgroup.WithContext(r.ctx) + r.totalChunk = int((r.input.Size-offset)/int64(r.input.PartSize)) + 1 + r.chunks = make([]chunkStatus, r.totalChunk) + r.inChan = make(chan int, r.totalChunk) + r.outChan = make(chan error, r.totalChunk) + r.currentOffset = offset + r.currentChunk = 0 + r.start = offset + for i := 0; i < r.totalChunk; i++ { + r.inChan <- i + } + for k := 0; k < r.input.Parallelism; k++ { + r.eg.Go(func() error { + for item := range r.inChan { + err := r.downloadChunk(item) + r.outChan <- err + if err != nil { + return err + } + } + return nil + }) + } + return nil +} + +func (r *s3Reader) ReadAt(p []byte, off int64) (n int, err error) { + if off >= r.input.Size { + return 0, io.EOF + } + if off != r.currentOffset { + if err := r.reset(off); err != nil { + return 0, err + } + } + + pOffset := 0 + for { + copied, err := r.readAtOneChunk(p[pOffset:]) + pOffset += copied + n += copied + if err != nil { + return n, err + } + if pOffset == len(p) { + break + } + } + + return n, nil +} + +func (r *s3Reader) readAtOneChunk(p []byte) (n int, err error) { + for { + if r.chunks[r.currentChunk].done { + break + } + err := <-r.outChan + if err != nil { + r.Close() + return 0, err + } + } + + copied := copy(p, r.chunks[r.currentChunk].buffer[r.chunks[r.currentChunk].readOffset:]) + r.chunks[r.currentChunk].readOffset += int64(copied) + r.currentOffset += int64(copied) + + if r.currentOffset >= r.chunks[r.currentChunk].start+r.chunks[r.currentChunk].size { + r.chunks[r.currentChunk].buffer = nil + r.currentChunk++ + } + + if r.currentChunk == r.totalChunk { + r.Close() + return copied, io.EOF + } + + return copied, nil +} + +func (r *s3Reader) Close() error { + if r.eg != nil { + close(r.inChan) + err := r.eg.Wait() + close(r.outChan) + r.eg = nil + r.currentOffset = -1 + r.inChan = nil + r.outChan = nil + r.chunks = nil + return err + } + return nil +}