From a342c6f3683831a856c92b41d2cdb0658a56ea3f Mon Sep 17 00:00:00 2001
From: Cody Littley <56973212+cody-littley@users.noreply.github.com>
Date: Thu, 31 Oct 2024 08:31:47 -0500
Subject: [PATCH 01/11] Revert "Revert "S3 relay interface" (#850)"

This reverts commit 855096a864210ad1272a594671ea6db57c0525f1.
---
 common/aws/cli.go                | 102 ++++++++--
 common/aws/s3/client.go          | 225 ++++++++++++++++++---
 common/aws/s3/fragment.go        | 128 ++++++++++++
 common/aws/s3/s3.go              |  39 +++-
 common/aws/test/client_test.go   | 176 +++++++++++++++++
 common/aws/test/fragment_test.go | 330 +++++++++++++++++++++++++++++++
 common/mock/s3_client.go         |  31 ++-
 inabox/deploy/localstack.go      |   2 +
 8 files changed, 992 insertions(+), 41 deletions(-)
 create mode 100644 common/aws/s3/fragment.go
 create mode 100644 common/aws/test/client_test.go
 create mode 100644 common/aws/test/fragment_test.go

diff --git a/common/aws/cli.go b/common/aws/cli.go
index 5a6d11503b..9a4a51b744 100644
--- a/common/aws/cli.go
+++ b/common/aws/cli.go
@@ -3,20 +3,48 @@ package aws
 import (
 	"github.com/Layr-Labs/eigenda/common"
 	"github.com/urfave/cli"
+	"time"
 )
 
 var (
-	RegionFlagName          = "aws.region"
-	AccessKeyIdFlagName     = "aws.access-key-id"
-	SecretAccessKeyFlagName = "aws.secret-access-key"
-	EndpointURLFlagName     = "aws.endpoint-url"
+	RegionFlagName                      = "aws.region"
+	AccessKeyIdFlagName                 = "aws.access-key-id"
+	SecretAccessKeyFlagName             = "aws.secret-access-key"
+	EndpointURLFlagName                 = "aws.endpoint-url"
+	FragmentPrefixCharsFlagName         = "aws.fragment-prefix-chars"
+	FragmentParallelismFactorFlagName   = "aws.fragment-parallelism-factor"
+	FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant"
+	FragmentReadTimeoutFlagName         = "aws.fragment-read-timeout"
+	FragmentWriteTimeoutFlagName        = "aws.fragment-write-timeout"
 )
 
 type ClientConfig struct {
-	Region          string
-	AccessKey       string
+	// Region is the region to use when interacting with S3. Default is "us-east-2".
+	Region string
+	// AccessKey to use when interacting with S3.
+	AccessKey string
+	// SecretAccessKey to use when interacting with S3.
 	SecretAccessKey string
-	EndpointURL     string
+	// EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used.
+	EndpointURL string
+
+	// FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files.
+	// A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3.
+	FragmentPrefixChars int
+	// FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files.
+	// A non-zero value sets the worker count to the number of cores times this value (taking precedence over
+	// FragmentParallelismConstant). Default is 8. In general, the worker count can be a lot larger than the
+	// number of cores because the workers will be blocked on I/O most of the time.
+	FragmentParallelismFactor int
+	// FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files.
+	// A non-zero value sets the worker count to this constant, unless the factor above is also set. Default is 0.
+	FragmentParallelismConstant int
+	// FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read.
+	// Default is 30 seconds.
+	FragmentReadTimeout time.Duration
+	// FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write.
+	// Default is 30 seconds.
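+	// Note that the timeout covers a whole FragmentedUploadObject call: all fragment writes for a file
+	// share one deadline rather than each fragment write being timed out individually.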
+	FragmentWriteTimeout time.Duration
 }
 
 func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
@@ -48,14 +76,66 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
 			Value:    "",
 			EnvVar:   common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"),
 		},
+		cli.IntFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName),
+			Usage:    "The number of characters of the key to use as the prefix for fragmented files",
+			Required: false,
+			Value:    3,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"),
+		},
+		cli.IntFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName),
+			Usage:    "Add this many threads times the number of cores to the worker pool",
+			Required: false,
+			Value:    8,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"),
+		},
+		cli.IntFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName),
+			Usage:    "Add this many threads to the worker pool",
+			Required: false,
+			Value:    0,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"),
+		},
+		cli.DurationFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName),
+			Usage:    "The maximum time to wait for a single fragmented read",
+			Required: false,
+			Value:    30 * time.Second,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"),
+		},
+		cli.DurationFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName),
+			Usage:    "The maximum time to wait for a single fragmented write",
+			Required: false,
+			Value:    30 * time.Second,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"),
+		},
 	}
 }
 
 func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig {
 	return ClientConfig{
-		Region:          ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
-		AccessKey:       ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
-		SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
-		EndpointURL:     ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
+		Region:                      ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
+		AccessKey:                   ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
+		SecretAccessKey:             ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
+		EndpointURL:                 ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
+		FragmentPrefixChars:         ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)),
+		FragmentParallelismFactor:   ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)),
+		FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)),
+		FragmentReadTimeout:         ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)),
+		FragmentWriteTimeout:        ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)),
 	}
 }
 
+// DefaultClientConfig returns a new ClientConfig with default values.
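+//
+// A minimal usage sketch (the endpoint and credential values shown here are illustrative, not defaults):
+//
+//	cfg := DefaultClientConfig()
+//	cfg.EndpointURL = "http://localhost:4566"
+//	cfg.AccessKey = "test"
+//	cfg.SecretAccessKey = "test"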
+func DefaultClientConfig() *ClientConfig { + return &ClientConfig{ + Region: "us-east-2", + FragmentPrefixChars: 3, + FragmentParallelismFactor: 8, + FragmentParallelismConstant: 0, + FragmentReadTimeout: 30 * time.Second, + FragmentWriteTimeout: 30 * time.Second, } } diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go index 231d546ae6..8a88318117 100644 --- a/common/aws/s3/client.go +++ b/common/aws/s3/client.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "errors" + "github.com/gammazero/workerpool" + "runtime" "sync" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -27,7 +29,9 @@ type Object struct { } type client struct { + cfg *commonaws.ClientConfig s3Client *s3.Client + pool *workerpool.WorkerPool logger logging.Logger } @@ -36,18 +40,19 @@ var _ Client = (*client)(nil) func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { var err error once.Do(func() { - customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { - if cfg.EndpointURL != "" { - return aws.Endpoint{ - PartitionID: "aws", - URL: cfg.EndpointURL, - SigningRegion: cfg.Region, - }, nil - } - - // returning EndpointNotFoundError will allow the service to fallback to its default resolution - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - }) + customResolver := aws.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if cfg.EndpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: cfg.EndpointURL, + SigningRegion: cfg.Region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) options := [](func(*config.LoadOptions) error){ config.WithRegion(cfg.Region), @@ -56,7 +61,9 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L } // If access key and secret access key are not provided, use the default credential provider if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { - options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) + options = append(options, + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) } awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) 
@@ -64,23 +71,32 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L err = errCfg return } + s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { o.UsePathStyle = true }) - ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")} - }) - return ref, err -} -func (s *client) CreateBucket(ctx context.Context, bucket string) error { - _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ - Bucket: aws.String(bucket), - }) - if err != nil { - return err - } + workers := 0 + if cfg.FragmentParallelismConstant > 0 { + workers = cfg.FragmentParallelismConstant + } + if cfg.FragmentParallelismFactor > 0 { + workers = cfg.FragmentParallelismFactor * runtime.NumCPU() + } - return nil + if workers == 0 { + workers = 1 + } + pool := workerpool.New(workers) + + ref = &client{ + cfg: &cfg, + s3Client: s3Client, + pool: pool, + logger: logger.With("component", "S3Client"), + } + }) + return ref, err } func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { @@ -159,3 +175,162 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) } return objects, nil } + +func (s *client) CreateBucket(ctx context.Context, bucket string) error { + _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + return err + } + + return nil +} + +func (s *client) FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error { + + fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) + if err != nil { + return err + } + resultChannel := make(chan error, len(fragments)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for _, fragment := range fragments { + fragmentCapture := fragment + s.pool.Submit(func() { + s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) + }) + } + + for range fragments { + err := <-resultChannel + if err != nil { + return err + } + } + return ctx.Err() + +} + +// fragmentedWriteTask writes a single file to S3. 
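+// It is intended to be run on the client's worker pool; the outcome of the PutObject call (nil on success,
+// otherwise the error) is reported back on resultChannel.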
+func (s *client) fragmentedWriteTask(
+	ctx context.Context,
+	resultChannel chan error,
+	fragment *Fragment,
+	bucket string) {
+
+	_, err := s.s3Client.PutObject(ctx,
+		&s3.PutObjectInput{
+			Bucket: aws.String(bucket),
+			Key:    aws.String(fragment.FragmentKey),
+			Body:   bytes.NewReader(fragment.Data),
+		})
+
+	resultChannel <- err
+}
+
+func (s *client) FragmentedDownloadObject(
+	ctx context.Context,
+	bucket string,
+	key string,
+	fileSize int,
+	fragmentSize int) ([]byte, error) {
+
+	if fragmentSize <= 0 {
+		return nil, errors.New("fragmentSize must be greater than 0")
+	}
+
+	fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize))
+	if err != nil {
+		return nil, err
+	}
+	resultChannel := make(chan *readResult, len(fragmentKeys))
+
+	ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentReadTimeout)
+	defer cancel()
+
+	for i, fragmentKey := range fragmentKeys {
+		boundFragmentKey := fragmentKey
+		boundI := i
+		s.pool.Submit(func() {
+			s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI)
+		})
+	}
+
+	fragments := make([]*Fragment, len(fragmentKeys))
+	for i := 0; i < len(fragmentKeys); i++ {
+		result := <-resultChannel
+		if result.err != nil {
+			return nil, result.err
+		}
+		fragments[result.fragment.Index] = result.fragment
+	}
+
+	if ctx.Err() != nil {
+		return nil, ctx.Err()
+	}
+
+	return RecombineFragments(fragments)
+
+}
+
+// readResult is the result of a read task.
+type readResult struct {
+	fragment *Fragment
+	err      error
+}
+
+// readTask reads a single file from S3.
+func (s *client) readTask(
+	ctx context.Context,
+	resultChannel chan *readResult,
+	bucket string,
+	key string,
+	index int) {
+
+	result := &readResult{}
+	defer func() {
+		resultChannel <- result
+	}()
+
+	ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+	})
+
+	if err != nil {
+		result.err = err
+		return
+	}
+
+	data := make([]byte, *ret.ContentLength)
+	bytesRead := 0
+
+	for bytesRead < len(data) && ctx.Err() == nil {
+		count, err := ret.Body.Read(data[bytesRead:])
+		if err != nil && err.Error() != "EOF" {
+			result.err = err
+			return
+		}
+		bytesRead += count
+	}
+
+	result.fragment = &Fragment{
+		FragmentKey: key,
+		Data:        data,
+		Index:       index,
+	}
+
+	err = ret.Body.Close()
+	if err != nil {
+		result.err = err
+	}
+}
diff --git a/common/aws/s3/fragment.go b/common/aws/s3/fragment.go
new file mode 100644
index 0000000000..6f978fbdc6
--- /dev/null
+++ b/common/aws/s3/fragment.go
@@ -0,0 +1,128 @@
+package s3
+
+import (
+	"fmt"
+	"sort"
+	"strings"
+)
+
+// GetFragmentCount returns the number of fragments that a file of the given size will be broken into.
+func GetFragmentCount(fileSize int, fragmentSize int) int {
+	if fileSize < fragmentSize {
+		return 1
+	} else if fileSize%fragmentSize == 0 {
+		return fileSize / fragmentSize
+	} else {
+		return fileSize/fragmentSize + 1
+	}
+}
+
+// GetFragmentKey returns the key for the fragment at the given index.
+//
+// Fragment keys take the form of "prefix/body-index[f]". The prefix is the first prefixLength characters
+// of the file key. The body is the file key. The index is the index of the fragment. The character "f" is appended
+// to the key of the last fragment in the series.
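+// Because the prefix is derived from the file key, all fragments of a given file share the same prefix
+// and are therefore grouped together under it.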
+// +// Example: fileKey="abc123", prefixLength=2, fragmentCount=3 +// The keys will be "ab/abc123-0", "ab/abc123-1", "ab/abc123-2f" +func GetFragmentKey(fileKey string, prefixLength int, fragmentCount int, index int) (string, error) { + var prefix string + if prefixLength > len(fileKey) { + prefix = fileKey + } else { + prefix = fileKey[:prefixLength] + } + + postfix := "" + if fragmentCount-1 == index { + postfix = "f" + } + + if index >= fragmentCount { + return "", fmt.Errorf("index %d is too high for fragment count %d", index, fragmentCount) + } + + return fmt.Sprintf("%s/%s-%d%s", prefix, fileKey, index, postfix), nil +} + +// Fragment is a subset of a file. +type Fragment struct { + FragmentKey string + Data []byte + Index int +} + +// BreakIntoFragments breaks a file into fragments of the given size. +func BreakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentSize int) ([]*Fragment, error) { + fragmentCount := GetFragmentCount(len(data), fragmentSize) + fragments := make([]*Fragment, fragmentCount) + for i := 0; i < fragmentCount; i++ { + start := i * fragmentSize + end := start + fragmentSize + if end > len(data) { + end = len(data) + } + + fragmentKey, err := GetFragmentKey(fileKey, prefixLength, fragmentCount, i) + if err != nil { + return nil, err + } + fragments[i] = &Fragment{ + FragmentKey: fragmentKey, + Data: data[start:end], + Index: i, + } + } + return fragments, nil +} + +// GetFragmentKeys returns the keys for all fragments of a file. +func GetFragmentKeys(fileKey string, prefixLength int, fragmentCount int) ([]string, error) { + keys := make([]string, fragmentCount) + for i := 0; i < fragmentCount; i++ { + fragmentKey, err := GetFragmentKey(fileKey, prefixLength, fragmentCount, i) + if err != nil { + return nil, err + } + keys[i] = fragmentKey + } + return keys, nil +} + +// RecombineFragments recombines fragments into a single file. +// Returns an error if any fragments are missing. +func RecombineFragments(fragments []*Fragment) ([]byte, error) { + + if len(fragments) == 0 { + return nil, fmt.Errorf("no fragments") + } + + // Sort the fragments by index + sort.Slice(fragments, func(i, j int) bool { + return fragments[i].Index < fragments[j].Index + }) + + // Make sure there aren't any gaps in the fragment indices + dataSize := 0 + for i, fragment := range fragments { + if fragment.Index != i { + return nil, fmt.Errorf("missing fragment with index %d", i) + } + dataSize += len(fragment.Data) + } + + // Make sure we have the last fragment + if !strings.HasSuffix(fragments[len(fragments)-1].FragmentKey, "f") { + return nil, fmt.Errorf("missing final fragment") + } + + fragmentSize := len(fragments[0].Data) + + // Concatenate the data + result := make([]byte, dataSize) + for _, fragment := range fragments { + copy(result[fragment.Index*fragmentSize:], fragment.Data) + } + + return result, nil +} diff --git a/common/aws/s3/s3.go b/common/aws/s3/s3.go index 475f68c941..74089099a9 100644 --- a/common/aws/s3/s3.go +++ b/common/aws/s3/s3.go @@ -2,10 +2,47 @@ package s3 import "context" +// Client encapsulates the functionality of an S3 client. type Client interface { - CreateBucket(ctx context.Context, bucket string) error + + // DownloadObject downloads an object from S3. DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) + + // UploadObject uploads an object to S3. UploadObject(ctx context.Context, bucket string, key string, data []byte) error + + // DeleteObject deletes an object from S3. 
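+	// Note that this deletes only the object stored directly under the given key; it does not remove
+	// fragments written by FragmentedUploadObject.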
 	DeleteObject(ctx context.Context, bucket string, key string) error
+
+	// ListObjects lists all objects in a bucket with the given prefix. Note that this method may return
+	// file fragments if the bucket contains files uploaded via FragmentedUploadObject.
 	ListObjects(ctx context.Context, bucket string, prefix string) ([]Object, error)
+
+	// CreateBucket creates a bucket in S3.
+	CreateBucket(ctx context.Context, bucket string) error
+
+	// FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each
+	// fragment uploaded to S3. If the file is larger than fragmentSize then it will be broken into
+	// smaller parts and uploaded in parallel. The file will be reassembled on download.
+	//
+	// Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to
+	// download the file. It is not advised to use DeleteObject on files uploaded with this method (if such
+	// functionality is required, a new method to do so should be added to this interface).
+	FragmentedUploadObject(
+		ctx context.Context,
+		bucket string,
+		key string,
+		data []byte,
+		fragmentSize int) error
+
+	// FragmentedDownloadObject downloads a file from S3, as written by FragmentedUploadObject. The fileSize (in
+	// bytes) must equal the length of the uploaded data, and fragmentSize must match the upload's fragmentSize.
+	//
+	// Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method.
+	FragmentedDownloadObject(
+		ctx context.Context,
+		bucket string,
+		key string,
+		fileSize int,
+		fragmentSize int) ([]byte, error)
 }
diff --git a/common/aws/test/client_test.go b/common/aws/test/client_test.go
new file mode 100644
index 0000000000..0f5bc4087a
--- /dev/null
+++ b/common/aws/test/client_test.go
@@ -0,0 +1,176 @@
+package test
+
+import (
+	"context"
+	"github.com/Layr-Labs/eigenda/common"
+	"github.com/Layr-Labs/eigenda/common/aws"
+	"github.com/Layr-Labs/eigenda/common/aws/s3"
+	"github.com/Layr-Labs/eigenda/common/mock"
+	tu "github.com/Layr-Labs/eigenda/common/testutils"
+	"github.com/Layr-Labs/eigenda/inabox/deploy"
+	"github.com/ory/dockertest/v3"
+	"github.com/stretchr/testify/assert"
+	"math/rand"
+	"os"
+	"testing"
+)
+
+var (
+	dockertestPool     *dockertest.Pool
+	dockertestResource *dockertest.Resource
+)
+
+const (
+	localstackPort = "4570"
+	localstackHost = "http://0.0.0.0:4570"
+	bucket         = "eigen-test"
+)
+
+type clientBuilder struct {
+	// This method is called at the beginning of the test.
+	start func() error
+	// This method is called to build a new client.
+	build func() (s3.Client, error)
+	// This method is called at the end of the test when all operations are done.
+ finish func() error +} + +var clientBuilders = []*clientBuilder{ + { + start: func() error { + return nil + }, + build: func() (s3.Client, error) { + return mock.NewS3Client(), nil + }, + finish: func() error { + return nil + }, + }, + { + start: func() error { + return setupLocalstack() + }, + build: func() (s3.Client, error) { + + logger, err := common.NewLogger(common.DefaultLoggerConfig()) + if err != nil { + return nil, err + } + + config := aws.DefaultClientConfig() + config.EndpointURL = localstackHost + config.Region = "us-east-1" + + err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack") + if err != nil { + return nil, err + } + err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack") + if err != nil { + return nil, err + } + + client, err := s3.NewClient(context.Background(), *config, logger) + if err != nil { + return nil, err + } + + err = client.CreateBucket(context.Background(), bucket) + if err != nil { + return nil, err + } + + return client, nil + }, + finish: func() error { + teardownLocalstack() + return nil + }, + }, +} + +func setupLocalstack() error { + deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") + + if deployLocalStack { + var err error + dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort) + if err != nil && err.Error() == "container already exists" { + teardownLocalstack() + return err + } + } + return nil +} + +func teardownLocalstack() { + deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") + + if deployLocalStack { + deploy.PurgeDockertestResources(dockertestPool, dockertestResource) + } +} + +func RandomOperationsTest(t *testing.T, client s3.Client) { + numberToWrite := 100 + expectedData := make(map[string][]byte) + + fragmentSize := rand.Intn(1000) + 1000 + + for i := 0; i < numberToWrite; i++ { + key := tu.RandomString(10) + fragmentMultiple := rand.Float64() * 10 + dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1 + data := tu.RandomBytes(dataSize) + expectedData[key] = data + + err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize) + assert.NoError(t, err) + } + + // Read back the data + for key, expected := range expectedData { + data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize) + assert.NoError(t, err) + assert.Equal(t, expected, data) + } +} + +func TestRandomOperations(t *testing.T) { + tu.InitializeRandom() + for _, builder := range clientBuilders { + err := builder.start() + assert.NoError(t, err) + + client, err := builder.build() + assert.NoError(t, err) + RandomOperationsTest(t, client) + + err = builder.finish() + assert.NoError(t, err) + } +} + +func ReadNonExistentValueTest(t *testing.T, client s3.Client) { + _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) + assert.Error(t, err) + randomKey := tu.RandomString(10) + _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) + assert.Error(t, err) +} + +func TestReadNonExistentValue(t *testing.T) { + tu.InitializeRandom() + for _, builder := range clientBuilders { + err := builder.start() + assert.NoError(t, err) + + client, err := builder.build() + assert.NoError(t, err) + ReadNonExistentValueTest(t, client) + + err = builder.finish() + assert.NoError(t, err) + } +} diff --git a/common/aws/test/fragment_test.go b/common/aws/test/fragment_test.go new file mode 100644 index 0000000000..fc5a257731 --- /dev/null +++ 
b/common/aws/test/fragment_test.go @@ -0,0 +1,330 @@ +package test + +import ( + "fmt" + "github.com/Layr-Labs/eigenda/common/aws/s3" + tu "github.com/Layr-Labs/eigenda/common/testutils" + "github.com/stretchr/testify/assert" + "math/rand" + "strings" + "testing" +) + +func TestGetFragmentCount(t *testing.T) { + tu.InitializeRandom() + + // Test a file smaller than a fragment + fileSize := rand.Intn(100) + 100 + fragmentSize := fileSize * 2 + fragmentCount := s3.GetFragmentCount(fileSize, fragmentSize) + assert.Equal(t, 1, fragmentCount) + + // Test a file that can fit in a single fragment + fileSize = rand.Intn(100) + 100 + fragmentSize = fileSize + fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) + assert.Equal(t, 1, fragmentCount) + + // Test a file that is one byte larger than a fragment + fileSize = rand.Intn(100) + 100 + fragmentSize = fileSize - 1 + fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) + assert.Equal(t, 2, fragmentCount) + + // Test a file that is one less than a multiple of the fragment size + fragmentSize = rand.Intn(100) + 100 + expectedFragmentCount := rand.Intn(10) + 1 + fileSize = fragmentSize*expectedFragmentCount - 1 + fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) + assert.Equal(t, expectedFragmentCount, fragmentCount) + + // Test a file that is a multiple of the fragment size + fragmentSize = rand.Intn(100) + 100 + expectedFragmentCount = rand.Intn(10) + 1 + fileSize = fragmentSize * expectedFragmentCount + fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) + assert.Equal(t, expectedFragmentCount, fragmentCount) + + // Test a file that is one more than a multiple of the fragment size + fragmentSize = rand.Intn(100) + 100 + expectedFragmentCount = rand.Intn(10) + 2 + fileSize = fragmentSize*(expectedFragmentCount-1) + 1 + fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) + assert.Equal(t, expectedFragmentCount, fragmentCount) +} + +// Fragment keys take the form of "prefix/body-index[f]". Verify the prefix part of the key. +func TestPrefix(t *testing.T) { + tu.InitializeRandom() + + keyLength := rand.Intn(10) + 10 + key := tu.RandomString(keyLength) + + for i := 0; i < keyLength*2; i++ { + fragmentCount := rand.Intn(10) + 10 + fragmentIndex := rand.Intn(fragmentCount) + fragmentKey, err := s3.GetFragmentKey(key, i, fragmentCount, fragmentIndex) + assert.NoError(t, err) + + parts := strings.Split(fragmentKey, "/") + assert.Equal(t, 2, len(parts)) + prefix := parts[0] + + if i >= keyLength { + assert.Equal(t, key, prefix) + } else { + assert.Equal(t, key[:i], prefix) + } + } +} + +// Fragment keys take the form of "prefix/body-index[f]". Verify the body part of the key. +func TestKeyBody(t *testing.T) { + tu.InitializeRandom() + + for i := 0; i < 10; i++ { + keyLength := rand.Intn(10) + 10 + key := tu.RandomString(keyLength) + fragmentCount := rand.Intn(10) + 10 + fragmentIndex := rand.Intn(fragmentCount) + fragmentKey, err := s3.GetFragmentKey(key, rand.Intn(10), fragmentCount, fragmentIndex) + assert.NoError(t, err) + + parts := strings.Split(fragmentKey, "/") + assert.Equal(t, 2, len(parts)) + parts = strings.Split(parts[1], "-") + assert.Equal(t, 2, len(parts)) + body := parts[0] + + assert.Equal(t, key, body) + } +} + +// Fragment keys take the form of "prefix/body-index[f]". Verify the index part of the key. 
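+// For example, in the key "ab/abc123-2f", the index is 2 and "f" marks the final fragment.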
+func TestKeyIndex(t *testing.T) { + tu.InitializeRandom() + + for i := 0; i < 10; i++ { + fragmentCount := rand.Intn(10) + 10 + index := rand.Intn(fragmentCount) + fragmentKey, err := s3.GetFragmentKey(tu.RandomString(10), rand.Intn(10), fragmentCount, index) + assert.NoError(t, err) + + parts := strings.Split(fragmentKey, "/") + assert.Equal(t, 2, len(parts)) + parts = strings.Split(parts[1], "-") + assert.Equal(t, 2, len(parts)) + indexStr := parts[1] + assert.True(t, strings.HasPrefix(indexStr, fmt.Sprintf("%d", index))) + } +} + +// Fragment keys take the form of "prefix/body-index[f]". +// Verify the postfix part of the key, which should be "f" for the last fragment. +func TestKeyPostfix(t *testing.T) { + tu.InitializeRandom() + + segmentCount := rand.Intn(10) + 10 + + for i := 0; i < segmentCount; i++ { + fragmentKey, err := s3.GetFragmentKey(tu.RandomString(10), rand.Intn(10), segmentCount, i) + assert.NoError(t, err) + + if i == segmentCount-1 { + assert.True(t, strings.HasSuffix(fragmentKey, "f")) + } else { + assert.False(t, strings.HasSuffix(fragmentKey, "f")) + } + } +} + +func TestExampleInGodoc(t *testing.T) { + fileKey := "abc123" + prefixLength := 2 + fragmentCount := 3 + fragmentKeys, err := s3.GetFragmentKeys(fileKey, prefixLength, fragmentCount) + assert.NoError(t, err) + assert.Equal(t, 3, len(fragmentKeys)) + assert.Equal(t, "ab/abc123-0", fragmentKeys[0]) + assert.Equal(t, "ab/abc123-1", fragmentKeys[1]) + assert.Equal(t, "ab/abc123-2f", fragmentKeys[2]) +} + +func TestGetFragmentKeys(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + prefixLength := rand.Intn(3) + 1 + fragmentCount := rand.Intn(10) + 10 + + fragmentKeys, err := s3.GetFragmentKeys(fileKey, prefixLength, fragmentCount) + assert.NoError(t, err) + assert.Equal(t, fragmentCount, len(fragmentKeys)) + + for i := 0; i < fragmentCount; i++ { + expectedKey, err := s3.GetFragmentKey(fileKey, prefixLength, fragmentCount, i) + assert.NoError(t, err) + assert.Equal(t, expectedKey, fragmentKeys[i]) + + parts := strings.Split(fragmentKeys[i], "/") + assert.Equal(t, 2, len(parts)) + parsedPrefix := parts[0] + assert.Equal(t, fileKey[:prefixLength], parsedPrefix) + parts = strings.Split(parts[1], "-") + assert.Equal(t, 2, len(parts)) + parsedKey := parts[0] + assert.Equal(t, fileKey, parsedKey) + index := parts[1] + + if i == fragmentCount-1 { + assert.Equal(t, fmt.Sprintf("%d", i)+"f", index) + } else { + assert.Equal(t, fmt.Sprintf("%d", i), index) + } + } +} + +func TestGetFragments(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, s3.GetFragmentCount(len(data), fragmentSize), len(fragments)) + + totalSize := 0 + + for i, fragment := range fragments { + fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, len(fragments), i) + assert.NoError(t, err) + assert.Equal(t, fragmentKey, fragment.FragmentKey) + + start := i * fragmentSize + end := start + fragmentSize + if end > len(data) { + end = len(data) + } + assert.Equal(t, data[start:end], fragment.Data) + assert.Equal(t, i, fragment.Index) + totalSize += len(fragment.Data) + } + + assert.Equal(t, len(data), totalSize) +} + +func TestGetFragmentsSmallFile(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(10) + prefixLength := 
rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, 1, len(fragments)) + + fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, 1, 0) + assert.NoError(t, err) + assert.Equal(t, fragmentKey, fragments[0].FragmentKey) + assert.Equal(t, data, fragments[0].Data) + assert.Equal(t, 0, fragments[0].Index) +} + +func TestGetFragmentsExactlyOnePerfectlySizedFile(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + fragmentSize := rand.Intn(100) + 100 + data := tu.RandomBytes(fragmentSize) + prefixLength := rand.Intn(3) + 1 + + fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, 1, len(fragments)) + + fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, 1, 0) + assert.NoError(t, err) + assert.Equal(t, fragmentKey, fragments[0].FragmentKey) + assert.Equal(t, data, fragments[0].Data) + assert.Equal(t, 0, fragments[0].Index) +} + +func TestRecombineFragments(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + recombinedData, err := s3.RecombineFragments(fragments) + assert.NoError(t, err) + assert.Equal(t, data, recombinedData) + + // Shuffle the fragments + for i := range fragments { + j := rand.Intn(i + 1) + fragments[i], fragments[j] = fragments[j], fragments[i] + } + + recombinedData, err = s3.RecombineFragments(fragments) + assert.NoError(t, err) + assert.Equal(t, data, recombinedData) +} + +func TestRecombineFragmentsSmallFile(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(10) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, 1, len(fragments)) + recombinedData, err := s3.RecombineFragments(fragments) + assert.NoError(t, err) + assert.Equal(t, data, recombinedData) +} + +func TestMissingFragment(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + + fragmentIndexToSkip := rand.Intn(len(fragments)) + fragments = append(fragments[:fragmentIndexToSkip], fragments[fragmentIndexToSkip+1:]...) 
+	_, err = s3.RecombineFragments(fragments)
+	assert.Error(t, err)
+}
+
+func TestMissingFinalFragment(t *testing.T) {
+	tu.InitializeRandom()
+
+	fileKey := tu.RandomString(10)
+	data := tu.RandomBytes(1000)
+	prefixLength := rand.Intn(3) + 1
+	fragmentSize := rand.Intn(100) + 100
+
+	fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize)
+	assert.NoError(t, err)
+	fragments = fragments[:len(fragments)-1]
+
+	_, err = s3.RecombineFragments(fragments)
+	assert.Error(t, err)
+}
diff --git a/common/mock/s3_client.go b/common/mock/s3_client.go
index d4e79645b0..7f505d56aa 100644
--- a/common/mock/s3_client.go
+++ b/common/mock/s3_client.go
@@ -17,10 +17,6 @@ func NewS3Client() *S3Client {
 	return &S3Client{bucket: make(map[string][]byte)}
 }
 
-func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error {
-	return nil
-}
-
 func (s *S3Client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) {
 	data, ok := s.bucket[key]
 	if !ok {
@@ -48,3 +44,30 @@ func (s *S3Client) ListObjects(ctx context.Context, bucket string, prefix string
 	}
 	return objects, nil
 }
+
+func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error {
+	return nil
+}
+
+func (s *S3Client) FragmentedUploadObject(
+	ctx context.Context,
+	bucket string,
+	key string,
+	data []byte,
+	fragmentSize int) error {
+	s.bucket[key] = data
+	return nil
+}
+
+func (s *S3Client) FragmentedDownloadObject(
+	ctx context.Context,
+	bucket string,
+	key string,
+	fileSize int,
+	fragmentSize int) ([]byte, error) {
+	data, ok := s.bucket[key]
+	if !ok {
+		return []byte{}, s3.ErrObjectNotFound
+	}
+	return data, nil
+}
diff --git a/inabox/deploy/localstack.go b/inabox/deploy/localstack.go
index 020f807b65..6a89bbf6ed 100644
--- a/inabox/deploy/localstack.go
+++ b/inabox/deploy/localstack.go
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"path/filepath"
 	"runtime"
+	"runtime/debug"
 	"time"
 
 	"github.com/Layr-Labs/eigenda/common/aws"
@@ -20,6 +21,7 @@ import (
 )
 
 func StartDockertestWithLocalstackContainer(localStackPort string) (*dockertest.Pool, *dockertest.Resource, error) {
+	debug.PrintStack() // TODO do not merge
 	fmt.Println("Starting Localstack container")
 	pool, err := dockertest.NewPool("")
 	if err != nil {

From fe82f90b6d5c41bda07579a92ca8f16b2b21606a Mon Sep 17 00:00:00 2001
From: Cody Littley
Date: Thu, 31 Oct 2024 08:55:02 -0500
Subject: [PATCH 02/11] Remove debug code.

Signed-off-by: Cody Littley --- inabox/deploy/localstack.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/inabox/deploy/localstack.go b/inabox/deploy/localstack.go index 6a89bbf6ed..020f807b65 100644 --- a/inabox/deploy/localstack.go +++ b/inabox/deploy/localstack.go @@ -8,7 +8,6 @@ import ( "net/http" "path/filepath" "runtime" - "runtime/debug" "time" "github.com/Layr-Labs/eigenda/common/aws" @@ -21,7 +20,6 @@ import ( ) func StartDockertestWithLocalstackContainer(localStackPort string) (*dockertest.Pool, *dockertest.Resource, error) { - debug.PrintStack() // TODO do not merge fmt.Println("Starting Localstack container") pool, err := dockertest.NewPool("") if err != nil { From 51397178796cd63d62369227e1f2aaf8c36dee0a Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 09:11:57 -0500 Subject: [PATCH 03/11] Temporarily disable code Signed-off-by: Cody Littley --- common/aws/s3/client-copy.txt | 336 +++++++++++++++++++++++++++++++ common/aws/s3/s3.go | 46 ++--- common/aws/test/client_test.go | 348 ++++++++++++++++----------------- 3 files changed, 533 insertions(+), 197 deletions(-) create mode 100644 common/aws/s3/client-copy.txt diff --git a/common/aws/s3/client-copy.txt b/common/aws/s3/client-copy.txt new file mode 100644 index 0000000000..8a88318117 --- /dev/null +++ b/common/aws/s3/client-copy.txt @@ -0,0 +1,336 @@ +package s3 + +import ( + "bytes" + "context" + "errors" + "github.com/gammazero/workerpool" + "runtime" + "sync" + + commonaws "github.com/Layr-Labs/eigenda/common/aws" + "github.com/Layr-Labs/eigensdk-go/logging" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +var ( + once sync.Once + ref *client + ErrObjectNotFound = errors.New("object not found") +) + +type Object struct { + Key string + Size int64 +} + +type client struct { + cfg *commonaws.ClientConfig + s3Client *s3.Client + pool *workerpool.WorkerPool + logger logging.Logger +} + +var _ Client = (*client)(nil) + +func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { + var err error + once.Do(func() { + customResolver := aws.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if cfg.EndpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: cfg.EndpointURL, + SigningRegion: cfg.Region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) + + options := [](func(*config.LoadOptions) error){ + config.WithRegion(cfg.Region), + config.WithEndpointResolverWithOptions(customResolver), + config.WithRetryMode(aws.RetryModeStandard), + } + // If access key and secret access key are not provided, use the default credential provider + if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { + options = append(options, + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) + } + awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) 
+ + if errCfg != nil { + err = errCfg + return + } + + s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { + o.UsePathStyle = true + }) + + workers := 0 + if cfg.FragmentParallelismConstant > 0 { + workers = cfg.FragmentParallelismConstant + } + if cfg.FragmentParallelismFactor > 0 { + workers = cfg.FragmentParallelismFactor * runtime.NumCPU() + } + + if workers == 0 { + workers = 1 + } + pool := workerpool.New(workers) + + ref = &client{ + cfg: &cfg, + s3Client: s3Client, + pool: pool, + logger: logger.With("component", "S3Client"), + } + }) + return ref, err +} + +func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { + var partMiBs int64 = 10 + downloader := manager.NewDownloader(s.s3Client, func(d *manager.Downloader) { + d.PartSize = partMiBs * 1024 * 1024 // 10MB per part + d.Concurrency = 3 //The number of goroutines to spin up in parallel per call to Upload when sending parts + }) + + buffer := manager.NewWriteAtBuffer([]byte{}) + _, err := downloader.Download(ctx, buffer, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, err + } + + if buffer == nil || len(buffer.Bytes()) == 0 { + return nil, ErrObjectNotFound + } + + return buffer.Bytes(), nil +} + +func (s *client) UploadObject(ctx context.Context, bucket string, key string, data []byte) error { + var partMiBs int64 = 10 + uploader := manager.NewUploader(s.s3Client, func(u *manager.Uploader) { + u.PartSize = partMiBs * 1024 * 1024 // 10MiB per part + u.Concurrency = 3 //The number of goroutines to spin up in parallel per call to upload when sending parts + }) + + _, err := uploader.Upload(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: bytes.NewReader(data), + }) + if err != nil { + return err + } + + return nil +} + +func (s *client) DeleteObject(ctx context.Context, bucket string, key string) error { + _, err := s.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return err + } + + return err +} + +func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) ([]Object, error) { + output, err := s.s3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + }) + if err != nil { + return nil, err + } + + objects := make([]Object, 0, len(output.Contents)) + for _, object := range output.Contents { + var size int64 = 0 + if object.Size != nil { + size = *object.Size + } + objects = append(objects, Object{ + Key: *object.Key, + Size: size, + }) + } + return objects, nil +} + +func (s *client) CreateBucket(ctx context.Context, bucket string) error { + _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + return err + } + + return nil +} + +func (s *client) FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error { + + fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) + if err != nil { + return err + } + resultChannel := make(chan error, len(fragments)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for _, fragment := range fragments { + fragmentCapture := fragment + s.pool.Submit(func() { + s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) + }) + } + + for range fragments { + err := <-resultChannel + if err != 
nil { + return err + } + } + return ctx.Err() + +} + +// fragmentedWriteTask writes a single file to S3. +func (s *client) fragmentedWriteTask( + ctx context.Context, + resultChannel chan error, + fragment *Fragment, + bucket string) { + + _, err := s.s3Client.PutObject(ctx, + &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fragment.FragmentKey), + Body: bytes.NewReader(fragment.Data), + }) + + resultChannel <- err +} + +func (s *client) FragmentedDownloadObject( + ctx context.Context, + bucket string, + key string, + fileSize int, + fragmentSize int) ([]byte, error) { + + if fragmentSize <= 0 { + return nil, errors.New("fragmentSize must be greater than 0") + } + + fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize)) + if err != nil { + return nil, err + } + resultChannel := make(chan *readResult, len(fragmentKeys)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for i, fragmentKey := range fragmentKeys { + boundFragmentKey := fragmentKey + boundI := i + s.pool.Submit(func() { + s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI) + }) + } + + fragments := make([]*Fragment, len(fragmentKeys)) + for i := 0; i < len(fragmentKeys); i++ { + result := <-resultChannel + if result.err != nil { + return nil, result.err + } + fragments[result.fragment.Index] = result.fragment + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + return RecombineFragments(fragments) + +} + +// readResult is the result of a read task. +type readResult struct { + fragment *Fragment + err error +} + +// readTask reads a single file from S3. +func (s *client) readTask( + ctx context.Context, + resultChannel chan *readResult, + bucket string, + key string, + index int) { + + result := &readResult{} + defer func() { + resultChannel <- result + }() + + ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + + if err != nil { + result.err = err + return + } + + data := make([]byte, *ret.ContentLength) + bytesRead := 0 + + for bytesRead < len(data) && ctx.Err() == nil { + count, err := ret.Body.Read(data[bytesRead:]) + if err != nil && err.Error() != "EOF" { + result.err = err + return + } + bytesRead += count + } + + result.fragment = &Fragment{ + FragmentKey: key, + Data: data, + Index: index, + } + + err = ret.Body.Close() + if err != nil { + result.err = err + } +} diff --git a/common/aws/s3/s3.go b/common/aws/s3/s3.go index 74089099a9..a470d40c3f 100644 --- a/common/aws/s3/s3.go +++ b/common/aws/s3/s3.go @@ -21,28 +21,28 @@ type Client interface { // CreateBucket creates a bucket in S3. CreateBucket(ctx context.Context, bucket string) error - // FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each - // file uploaded to S3. If the file is larger than fragmentSize then it will be broken into - // smaller parts and uploaded in parallel. The file will be reassembled on download. + //// FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each + //// file uploaded to S3. If the file is larger than fragmentSize then it will be broken into + //// smaller parts and uploaded in parallel. The file will be reassembled on download. + //// + //// Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to + //// download the file. 
It is not advised to use DeleteObject on files uploaded with this method (if such + //// functionality is required, a new method to do so should be added to this interface). + //FragmentedUploadObject( + // ctx context.Context, + // bucket string, + // key string, + // data []byte, + // fragmentSize int) error // - // Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to - // download the file. It is not advised to use DeleteObject on files uploaded with this method (if such - // functionality is required, a new method to do so should be added to this interface). - FragmentedUploadObject( - ctx context.Context, - bucket string, - key string, - data []byte, - fragmentSize int) error - - // FragmentedDownloadObject downloads a file from S3, as written by Upload. The fileSize (in bytes) and fragmentSize - // must be the same as the values used in the FragmentedUploadObject call. - // - // Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method. - FragmentedDownloadObject( - ctx context.Context, - bucket string, - key string, - fileSize int, - fragmentSize int) ([]byte, error) + //// FragmentedDownloadObject downloads a file from S3, as written by Upload. The fileSize (in bytes) and fragmentSize + //// must be the same as the values used in the FragmentedUploadObject call. + //// + //// Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method. + //FragmentedDownloadObject( + // ctx context.Context, + // bucket string, + // key string, + // fileSize int, + // fragmentSize int) ([]byte, error) } diff --git a/common/aws/test/client_test.go b/common/aws/test/client_test.go index 0f5bc4087a..dfd935d973 100644 --- a/common/aws/test/client_test.go +++ b/common/aws/test/client_test.go @@ -1,176 +1,176 @@ package test -import ( - "context" - "github.com/Layr-Labs/eigenda/common" - "github.com/Layr-Labs/eigenda/common/aws" - "github.com/Layr-Labs/eigenda/common/aws/s3" - "github.com/Layr-Labs/eigenda/common/mock" - tu "github.com/Layr-Labs/eigenda/common/testutils" - "github.com/Layr-Labs/eigenda/inabox/deploy" - "github.com/ory/dockertest/v3" - "github.com/stretchr/testify/assert" - "math/rand" - "os" - "testing" -) - -var ( - dockertestPool *dockertest.Pool - dockertestResource *dockertest.Resource -) - -const ( - localstackPort = "4570" - localstackHost = "http://0.0.0.0:4570" - bucket = "eigen-test" -) - -type clientBuilder struct { - // This method is called at the beginning of the test. - start func() error - // This method is called to build a new client. - build func() (s3.Client, error) - // This method is called at the end of the test when all operations are done. 
- finish func() error -} - -var clientBuilders = []*clientBuilder{ - { - start: func() error { - return nil - }, - build: func() (s3.Client, error) { - return mock.NewS3Client(), nil - }, - finish: func() error { - return nil - }, - }, - { - start: func() error { - return setupLocalstack() - }, - build: func() (s3.Client, error) { - - logger, err := common.NewLogger(common.DefaultLoggerConfig()) - if err != nil { - return nil, err - } - - config := aws.DefaultClientConfig() - config.EndpointURL = localstackHost - config.Region = "us-east-1" - - err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack") - if err != nil { - return nil, err - } - err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack") - if err != nil { - return nil, err - } - - client, err := s3.NewClient(context.Background(), *config, logger) - if err != nil { - return nil, err - } - - err = client.CreateBucket(context.Background(), bucket) - if err != nil { - return nil, err - } - - return client, nil - }, - finish: func() error { - teardownLocalstack() - return nil - }, - }, -} - -func setupLocalstack() error { - deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") - - if deployLocalStack { - var err error - dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort) - if err != nil && err.Error() == "container already exists" { - teardownLocalstack() - return err - } - } - return nil -} - -func teardownLocalstack() { - deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") - - if deployLocalStack { - deploy.PurgeDockertestResources(dockertestPool, dockertestResource) - } -} - -func RandomOperationsTest(t *testing.T, client s3.Client) { - numberToWrite := 100 - expectedData := make(map[string][]byte) - - fragmentSize := rand.Intn(1000) + 1000 - - for i := 0; i < numberToWrite; i++ { - key := tu.RandomString(10) - fragmentMultiple := rand.Float64() * 10 - dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1 - data := tu.RandomBytes(dataSize) - expectedData[key] = data - - err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize) - assert.NoError(t, err) - } - - // Read back the data - for key, expected := range expectedData { - data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize) - assert.NoError(t, err) - assert.Equal(t, expected, data) - } -} - -func TestRandomOperations(t *testing.T) { - tu.InitializeRandom() - for _, builder := range clientBuilders { - err := builder.start() - assert.NoError(t, err) - - client, err := builder.build() - assert.NoError(t, err) - RandomOperationsTest(t, client) - - err = builder.finish() - assert.NoError(t, err) - } -} - -func ReadNonExistentValueTest(t *testing.T, client s3.Client) { - _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) - assert.Error(t, err) - randomKey := tu.RandomString(10) - _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) - assert.Error(t, err) -} - -func TestReadNonExistentValue(t *testing.T) { - tu.InitializeRandom() - for _, builder := range clientBuilders { - err := builder.start() - assert.NoError(t, err) - - client, err := builder.build() - assert.NoError(t, err) - ReadNonExistentValueTest(t, client) - - err = builder.finish() - assert.NoError(t, err) - } -} +//import ( +// "context" +// "github.com/Layr-Labs/eigenda/common" +// "github.com/Layr-Labs/eigenda/common/aws" +// "github.com/Layr-Labs/eigenda/common/aws/s3" +// 
"github.com/Layr-Labs/eigenda/common/mock" +// tu "github.com/Layr-Labs/eigenda/common/testutils" +// "github.com/Layr-Labs/eigenda/inabox/deploy" +// "github.com/ory/dockertest/v3" +// "github.com/stretchr/testify/assert" +// "math/rand" +// "os" +// "testing" +//) +// +//var ( +// dockertestPool *dockertest.Pool +// dockertestResource *dockertest.Resource +//) +// +//const ( +// localstackPort = "4570" +// localstackHost = "http://0.0.0.0:4570" +// bucket = "eigen-test" +//) +// +//type clientBuilder struct { +// // This method is called at the beginning of the test. +// start func() error +// // This method is called to build a new client. +// build func() (s3.Client, error) +// // This method is called at the end of the test when all operations are done. +// finish func() error +//} +// +//var clientBuilders = []*clientBuilder{ +// { +// start: func() error { +// return nil +// }, +// build: func() (s3.Client, error) { +// return mock.NewS3Client(), nil +// }, +// finish: func() error { +// return nil +// }, +// }, +// { +// start: func() error { +// return setupLocalstack() +// }, +// build: func() (s3.Client, error) { +// +// logger, err := common.NewLogger(common.DefaultLoggerConfig()) +// if err != nil { +// return nil, err +// } +// +// config := aws.DefaultClientConfig() +// config.EndpointURL = localstackHost +// config.Region = "us-east-1" +// +// err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack") +// if err != nil { +// return nil, err +// } +// err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack") +// if err != nil { +// return nil, err +// } +// +// client, err := s3.NewClient(context.Background(), *config, logger) +// if err != nil { +// return nil, err +// } +// +// err = client.CreateBucket(context.Background(), bucket) +// if err != nil { +// return nil, err +// } +// +// return client, nil +// }, +// finish: func() error { +// teardownLocalstack() +// return nil +// }, +// }, +//} +// +//func setupLocalstack() error { +// deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") +// +// if deployLocalStack { +// var err error +// dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort) +// if err != nil && err.Error() == "container already exists" { +// teardownLocalstack() +// return err +// } +// } +// return nil +//} +// +//func teardownLocalstack() { +// deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") +// +// if deployLocalStack { +// deploy.PurgeDockertestResources(dockertestPool, dockertestResource) +// } +//} +// +//func RandomOperationsTest(t *testing.T, client s3.Client) { +// numberToWrite := 100 +// expectedData := make(map[string][]byte) +// +// fragmentSize := rand.Intn(1000) + 1000 +// +// for i := 0; i < numberToWrite; i++ { +// key := tu.RandomString(10) +// fragmentMultiple := rand.Float64() * 10 +// dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1 +// data := tu.RandomBytes(dataSize) +// expectedData[key] = data +// +// err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize) +// assert.NoError(t, err) +// } +// +// // Read back the data +// for key, expected := range expectedData { +// data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize) +// assert.NoError(t, err) +// assert.Equal(t, expected, data) +// } +//} +// +//func TestRandomOperations(t *testing.T) { +// tu.InitializeRandom() +// for _, builder := range clientBuilders { +// err := builder.start() +// assert.NoError(t, err) 
+// +// client, err := builder.build() +// assert.NoError(t, err) +// RandomOperationsTest(t, client) +// +// err = builder.finish() +// assert.NoError(t, err) +// } +//} +// +//func ReadNonExistentValueTest(t *testing.T, client s3.Client) { +// _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) +// assert.Error(t, err) +// randomKey := tu.RandomString(10) +// _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) +// assert.Error(t, err) +//} +// +//func TestReadNonExistentValue(t *testing.T) { +// tu.InitializeRandom() +// for _, builder := range clientBuilders { +// err := builder.start() +// assert.NoError(t, err) +// +// client, err := builder.build() +// assert.NoError(t, err) +// ReadNonExistentValueTest(t, client) +// +// err = builder.finish() +// assert.NoError(t, err) +// } +//} From 42c1dbde9e0536b7ba9e8c18930c89d935e9800d Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 09:12:43 -0500 Subject: [PATCH 04/11] Disable more stuff Signed-off-by: Cody Littley --- common/aws/s3/client.go | 225 +++++----------------------------------- 1 file changed, 25 insertions(+), 200 deletions(-) diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go index 8a88318117..231d546ae6 100644 --- a/common/aws/s3/client.go +++ b/common/aws/s3/client.go @@ -4,8 +4,6 @@ import ( "bytes" "context" "errors" - "github.com/gammazero/workerpool" - "runtime" "sync" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -29,9 +27,7 @@ type Object struct { } type client struct { - cfg *commonaws.ClientConfig s3Client *s3.Client - pool *workerpool.WorkerPool logger logging.Logger } @@ -40,19 +36,18 @@ var _ Client = (*client)(nil) func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { var err error once.Do(func() { - customResolver := aws.EndpointResolverWithOptionsFunc( - func(service, region string, options ...interface{}) (aws.Endpoint, error) { - if cfg.EndpointURL != "" { - return aws.Endpoint{ - PartitionID: "aws", - URL: cfg.EndpointURL, - SigningRegion: cfg.Region, - }, nil - } - - // returning EndpointNotFoundError will allow the service to fallback to its default resolution - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - }) + customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if cfg.EndpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: cfg.EndpointURL, + SigningRegion: cfg.Region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) options := [](func(*config.LoadOptions) error){ config.WithRegion(cfg.Region), @@ -61,9 +56,7 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L } // If access key and secret access key are not provided, use the default credential provider if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { - options = append(options, - config.WithCredentialsProvider( - credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) + options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) } awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) 
@@ -71,34 +64,25 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L err = errCfg return } - s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { o.UsePathStyle = true }) - - workers := 0 - if cfg.FragmentParallelismConstant > 0 { - workers = cfg.FragmentParallelismConstant - } - if cfg.FragmentParallelismFactor > 0 { - workers = cfg.FragmentParallelismFactor * runtime.NumCPU() - } - - if workers == 0 { - workers = 1 - } - pool := workerpool.New(workers) - - ref = &client{ - cfg: &cfg, - s3Client: s3Client, - pool: pool, - logger: logger.With("component", "S3Client"), - } + ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")} }) return ref, err } +func (s *client) CreateBucket(ctx context.Context, bucket string) error { + _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + return err + } + + return nil +} + func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { var partMiBs int64 = 10 downloader := manager.NewDownloader(s.s3Client, func(d *manager.Downloader) { @@ -175,162 +159,3 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) } return objects, nil } - -func (s *client) CreateBucket(ctx context.Context, bucket string) error { - _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ - Bucket: aws.String(bucket), - }) - if err != nil { - return err - } - - return nil -} - -func (s *client) FragmentedUploadObject( - ctx context.Context, - bucket string, - key string, - data []byte, - fragmentSize int) error { - - fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) - if err != nil { - return err - } - resultChannel := make(chan error, len(fragments)) - - ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) - defer cancel() - - for _, fragment := range fragments { - fragmentCapture := fragment - s.pool.Submit(func() { - s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) - }) - } - - for range fragments { - err := <-resultChannel - if err != nil { - return err - } - } - return ctx.Err() - -} - -// fragmentedWriteTask writes a single file to S3. 
-func (s *client) fragmentedWriteTask( - ctx context.Context, - resultChannel chan error, - fragment *Fragment, - bucket string) { - - _, err := s.s3Client.PutObject(ctx, - &s3.PutObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(fragment.FragmentKey), - Body: bytes.NewReader(fragment.Data), - }) - - resultChannel <- err -} - -func (s *client) FragmentedDownloadObject( - ctx context.Context, - bucket string, - key string, - fileSize int, - fragmentSize int) ([]byte, error) { - - if fragmentSize <= 0 { - return nil, errors.New("fragmentSize must be greater than 0") - } - - fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize)) - if err != nil { - return nil, err - } - resultChannel := make(chan *readResult, len(fragmentKeys)) - - ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) - defer cancel() - - for i, fragmentKey := range fragmentKeys { - boundFragmentKey := fragmentKey - boundI := i - s.pool.Submit(func() { - s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI) - }) - } - - fragments := make([]*Fragment, len(fragmentKeys)) - for i := 0; i < len(fragmentKeys); i++ { - result := <-resultChannel - if result.err != nil { - return nil, result.err - } - fragments[result.fragment.Index] = result.fragment - } - - if ctx.Err() != nil { - return nil, ctx.Err() - } - - return RecombineFragments(fragments) - -} - -// readResult is the result of a read task. -type readResult struct { - fragment *Fragment - err error -} - -// readTask reads a single file from S3. -func (s *client) readTask( - ctx context.Context, - resultChannel chan *readResult, - bucket string, - key string, - index int) { - - result := &readResult{} - defer func() { - resultChannel <- result - }() - - ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - - if err != nil { - result.err = err - return - } - - data := make([]byte, *ret.ContentLength) - bytesRead := 0 - - for bytesRead < len(data) && ctx.Err() == nil { - count, err := ret.Body.Read(data[bytesRead:]) - if err != nil && err.Error() != "EOF" { - result.err = err - return - } - bytesRead += count - } - - result.fragment = &Fragment{ - FragmentKey: key, - Data: data, - Index: index, - } - - err = ret.Body.Close() - if err != nil { - result.err = err - } -} From c78fa06cbb843f380598b0f90a253e8bbe68ef38 Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 09:25:16 -0500 Subject: [PATCH 05/11] Disable more code. 
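For reviewers: the fragment settings being parked in cli-copy.txt feed the fragment-count math behind FragmentedUploadObject and FragmentedDownloadObject, where a file of fileSize bytes is stored as ceil(fileSize / fragmentSize) objects. A sketch of that arithmetic; the real GetFragmentCount helper is not shown in this series, so treat the exact rounding as an assumption:

    // fragmentCount is a hypothetical stand-in for GetFragmentCount:
    // ceiling division, so a partial trailing fragment still gets its
    // own object in S3.
    func fragmentCount(fileSize, fragmentSize int) int {
        count := fileSize / fragmentSize
        if fileSize%fragmentSize != 0 {
            count++
        }
        return count
    }

For example, a 2500-byte file with a fragmentSize of 1000 would be stored as 3 objects.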
Signed-off-by: Cody Littley --- common/aws/cli-copy.txt | 141 ++++++++++++++++++++++++++++++++++++++++ common/aws/cli.go | 102 ++++------------------------- 2 files changed, 152 insertions(+), 91 deletions(-) create mode 100644 common/aws/cli-copy.txt diff --git a/common/aws/cli-copy.txt b/common/aws/cli-copy.txt new file mode 100644 index 0000000000..9a4a51b744 --- /dev/null +++ b/common/aws/cli-copy.txt @@ -0,0 +1,141 @@ +package aws + +import ( + "github.com/Layr-Labs/eigenda/common" + "github.com/urfave/cli" + "time" +) + +var ( + RegionFlagName = "aws.region" + AccessKeyIdFlagName = "aws.access-key-id" + SecretAccessKeyFlagName = "aws.secret-access-key" + EndpointURLFlagName = "aws.endpoint-url" + FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars" + FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor" + FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant" + FragmentReadTimeoutFlagName = "aws.fragment-read-timeout" + FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout" +) + +type ClientConfig struct { + // Region is the region to use when interacting with S3. Default is "us-east-2". + Region string + // AccessKey to use when interacting with S3. + AccessKey string + // SecretAccessKey to use when interacting with S3. + SecretAccessKey string + // EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used. + EndpointURL string + + // FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files. + // A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3. + FragmentPrefixChars int + // FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files. + // A non-zero value for this parameter adds a number of workers equal to the number of cores times this value. + // Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the + // workers will be blocked on I/O most of the time. + FragmentParallelismFactor int + // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files. + // A non-zero value for this parameter adds a constant number of workers. Default is 0. + FragmentParallelismConstant int + // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read. + // Default is 30 seconds. + FragmentReadTimeout time.Duration + // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write. + // Default is 30 seconds. 
+ FragmentWriteTimeout time.Duration +} + +func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { + return []cli.Flag{ + cli.StringFlag{ + Name: common.PrefixFlag(flagPrefix, RegionFlagName), + Usage: "AWS Region", + Required: true, + EnvVar: common.PrefixEnvVar(envPrefix, "AWS_REGION"), + }, + cli.StringFlag{ + Name: common.PrefixFlag(flagPrefix, AccessKeyIdFlagName), + Usage: "AWS Access Key Id", + Required: false, + Value: "", + EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ACCESS_KEY_ID"), + }, + cli.StringFlag{ + Name: common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName), + Usage: "AWS Secret Access Key", + Required: false, + Value: "", + EnvVar: common.PrefixEnvVar(envPrefix, "AWS_SECRET_ACCESS_KEY"), + }, + cli.StringFlag{ + Name: common.PrefixFlag(flagPrefix, EndpointURLFlagName), + Usage: "AWS Endpoint URL", + Required: false, + Value: "", + EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + Usage: "The number of characters of the key to use as the prefix for fragmented files", + Required: false, + Value: 3, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + Usage: "Add this many threads times the number of cores to the worker pool", + Required: false, + Value: 8, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), + Usage: "Add this many threads to the worker pool", + Required: false, + Value: 0, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented read", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented write", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), + }, + } +} + +func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { + return ClientConfig{ + Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), + AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), + SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), + EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), + FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), + FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), + FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), + FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), + FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), + } +} + +// DefaultClientConfig returns a new ClientConfig with default values. 
+func DefaultClientConfig() *ClientConfig { + return &ClientConfig{ + Region: "us-east-2", + FragmentPrefixChars: 3, + FragmentParallelismFactor: 8, + FragmentParallelismConstant: 0, + FragmentReadTimeout: 30 * time.Second, + FragmentWriteTimeout: 30 * time.Second, + } +} diff --git a/common/aws/cli.go b/common/aws/cli.go index 9a4a51b744..5a6d11503b 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -3,48 +3,20 @@ package aws import ( "github.com/Layr-Labs/eigenda/common" "github.com/urfave/cli" - "time" ) var ( - RegionFlagName = "aws.region" - AccessKeyIdFlagName = "aws.access-key-id" - SecretAccessKeyFlagName = "aws.secret-access-key" - EndpointURLFlagName = "aws.endpoint-url" - FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars" - FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor" - FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant" - FragmentReadTimeoutFlagName = "aws.fragment-read-timeout" - FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout" + RegionFlagName = "aws.region" + AccessKeyIdFlagName = "aws.access-key-id" + SecretAccessKeyFlagName = "aws.secret-access-key" + EndpointURLFlagName = "aws.endpoint-url" ) type ClientConfig struct { - // Region is the region to use when interacting with S3. Default is "us-east-2". - Region string - // AccessKey to use when interacting with S3. - AccessKey string - // SecretAccessKey to use when interacting with S3. + Region string + AccessKey string SecretAccessKey string - // EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used. - EndpointURL string - - // FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files. - // A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3. - FragmentPrefixChars int - // FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files. - // A non-zero value for this parameter adds a number of workers equal to the number of cores times this value. - // Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the - // workers will be blocked on I/O most of the time. - FragmentParallelismFactor int - // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files. - // A non-zero value for this parameter adds a constant number of workers. Default is 0. - FragmentParallelismConstant int - // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read. - // Default is 30 seconds. - FragmentReadTimeout time.Duration - // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write. - // Default is 30 seconds. 
- FragmentWriteTimeout time.Duration + EndpointURL string } func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { @@ -76,66 +48,14 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "The number of characters of the key to use as the prefix for fragmented files", - Required: false, - Value: 3, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "Add this many threads times the number of cores to the worker pool", - Required: false, - Value: 8, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), - Usage: "Add this many threads to the worker pool", - Required: false, - Value: 0, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented read", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented write", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), - }, } } func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { return ClientConfig{ - Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), - AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), - SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), - EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), - FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), - FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), - FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), - FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), - FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), - } -} - -// DefaultClientConfig returns a new ClientConfig with default values. -func DefaultClientConfig() *ClientConfig { - return &ClientConfig{ - Region: "us-east-2", - FragmentPrefixChars: 3, - FragmentParallelismFactor: 8, - FragmentParallelismConstant: 0, - FragmentReadTimeout: 30 * time.Second, - FragmentWriteTimeout: 30 * time.Second, + Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), + AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), + SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), + EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), } } From f95b3c24a893d40fb0ff9ee4df72471096a14320 Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 09:48:15 -0500 Subject: [PATCH 06/11] Experimentation. 
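Context for the experiment: DefaultClientConfig gives callers a fully populated baseline, and the currently disabled localstack tests override just the endpoint and region on top of it. Roughly, a call site looks like this; a sketch only, with the endpoint, region, and logger setup taken from the disabled test code and therefore illustrative:

    // Build a config from defaults, pointing at a local S3-compatible endpoint.
    cfg := aws.DefaultClientConfig()
    cfg.EndpointURL = "http://0.0.0.0:4570" // e.g. a localstack instance
    cfg.Region = "us-east-1"

    logger, err := common.NewLogger(common.DefaultLoggerConfig())
    if err != nil {
        return err
    }

    client, err := s3.NewClient(context.Background(), *cfg, logger)
    if err != nil {
        return err
    }
    _ = client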
Signed-off-by: Cody Littley --- common/aws/cli.go | 94 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 7 deletions(-) diff --git a/common/aws/cli.go b/common/aws/cli.go index 5a6d11503b..91bc0f374d 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -3,20 +3,48 @@ package aws import ( "github.com/Layr-Labs/eigenda/common" "github.com/urfave/cli" + "time" ) var ( - RegionFlagName = "aws.region" - AccessKeyIdFlagName = "aws.access-key-id" - SecretAccessKeyFlagName = "aws.secret-access-key" - EndpointURLFlagName = "aws.endpoint-url" + RegionFlagName = "aws.region" + AccessKeyIdFlagName = "aws.access-key-id" + SecretAccessKeyFlagName = "aws.secret-access-key" + EndpointURLFlagName = "aws.endpoint-url" + FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars" + FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor" + FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant" + FragmentReadTimeoutFlagName = "aws.fragment-read-timeout" + FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout" ) type ClientConfig struct { - Region string - AccessKey string + // Region is the region to use when interacting with S3. Default is "us-east-2". + Region string + // AccessKey to use when interacting with S3. + AccessKey string + // SecretAccessKey to use when interacting with S3. SecretAccessKey string - EndpointURL string + // EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used. + EndpointURL string + + // FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files. + // A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3. + FragmentPrefixChars int + // FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files. + // A non-zero value for this parameter adds a number of workers equal to the number of cores times this value. + // Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the + // workers will be blocked on I/O most of the time. + FragmentParallelismFactor int + // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files. + // A non-zero value for this parameter adds a constant number of workers. Default is 0. + FragmentParallelismConstant int + // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read. + // Default is 30 seconds. + FragmentReadTimeout time.Duration + // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write. + // Default is 30 seconds. 
+ FragmentWriteTimeout time.Duration } func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { @@ -48,6 +76,41 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + Usage: "The number of characters of the key to use as the prefix for fragmented files", + Required: false, + Value: 3, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + Usage: "Add this many threads times the number of cores to the worker pool", + Required: false, + Value: 8, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), + Usage: "Add this many threads to the worker pool", + Required: false, + Value: 0, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented read", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented write", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), + }, } } @@ -57,5 +120,22 @@ func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), + //FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), + //FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), + //FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), + //FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), + //FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), + } +} + +// DefaultClientConfig returns a new ClientConfig with default values. 
+func DefaultClientConfig() *ClientConfig { + return &ClientConfig{ + Region: "us-east-2", + FragmentPrefixChars: 3, + FragmentParallelismFactor: 8, + FragmentParallelismConstant: 0, + FragmentReadTimeout: 30 * time.Second, + FragmentWriteTimeout: 30 * time.Second, } } From 108e5403712efa49968250c4621236c21a6f005f Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 10:00:02 -0500 Subject: [PATCH 07/11] Disable more Signed-off-by: Cody Littley --- common/aws/cli.go | 70 +++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/common/aws/cli.go b/common/aws/cli.go index 91bc0f374d..59b93b4ab0 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -76,41 +76,41 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "The number of characters of the key to use as the prefix for fragmented files", - Required: false, - Value: 3, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "Add this many threads times the number of cores to the worker pool", - Required: false, - Value: 8, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), - Usage: "Add this many threads to the worker pool", - Required: false, - Value: 0, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented read", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented write", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), - }, + //cli.IntFlag{ + // Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + // Usage: "The number of characters of the key to use as the prefix for fragmented files", + // Required: false, + // Value: 3, + // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), + //}, + //cli.IntFlag{ + // Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + // Usage: "Add this many threads times the number of cores to the worker pool", + // Required: false, + // Value: 8, + // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), + //}, + //cli.IntFlag{ + // Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), + // Usage: "Add this many threads to the worker pool", + // Required: false, + // Value: 0, + // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), + //}, + //cli.DurationFlag{ + // Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), + // Usage: "The maximum time to wait for a single fragmented read", + // Required: false, + // Value: 30 * time.Second, + // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), + //}, + //cli.DurationFlag{ + // Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), + // Usage: "The maximum time to wait for a single 
fragmented write", + // Required: false, + // Value: 30 * time.Second, + // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), + //}, } } From 3aa2fb050e348f3109b54e088c06e7d9eec9ddc8 Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 10:13:46 -0500 Subject: [PATCH 08/11] Fix flag. Signed-off-by: Cody Littley --- common/aws/cli.go | 70 +++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/common/aws/cli.go b/common/aws/cli.go index 59b93b4ab0..e10d29f083 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -76,41 +76,41 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), }, - //cli.IntFlag{ - // Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - // Usage: "The number of characters of the key to use as the prefix for fragmented files", - // Required: false, - // Value: 3, - // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), - //}, - //cli.IntFlag{ - // Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - // Usage: "Add this many threads times the number of cores to the worker pool", - // Required: false, - // Value: 8, - // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), - //}, - //cli.IntFlag{ - // Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), - // Usage: "Add this many threads to the worker pool", - // Required: false, - // Value: 0, - // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), - //}, - //cli.DurationFlag{ - // Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), - // Usage: "The maximum time to wait for a single fragmented read", - // Required: false, - // Value: 30 * time.Second, - // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), - //}, - //cli.DurationFlag{ - // Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), - // Usage: "The maximum time to wait for a single fragmented write", - // Required: false, - // Value: 30 * time.Second, - // EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), - //}, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName), + Usage: "The number of characters of the key to use as the prefix for fragmented files", + Required: false, + Value: 3, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + Usage: "Add this many threads times the number of cores to the worker pool", + Required: false, + Value: 8, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), + Usage: "Add this many threads to the worker pool", + Required: false, + Value: 0, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented read", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented write", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, 
"FRAGMENT_WRITE_TIMEOUT"), + }, } } From 22304d901232c577ed4847a7589173759a11f348 Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 10:26:20 -0500 Subject: [PATCH 09/11] Revert changes. Signed-off-by: Cody Littley --- common/aws/cli-copy.txt | 141 ------------- common/aws/cli.go | 18 +- common/aws/s3/client-copy.txt | 336 ------------------------------- common/aws/s3/client.go | 225 ++++++++++++++++++--- common/aws/s3/s3.go | 46 ++--- common/aws/test/client_test.go | 348 ++++++++++++++++----------------- 6 files changed, 406 insertions(+), 708 deletions(-) delete mode 100644 common/aws/cli-copy.txt delete mode 100644 common/aws/s3/client-copy.txt diff --git a/common/aws/cli-copy.txt b/common/aws/cli-copy.txt deleted file mode 100644 index 9a4a51b744..0000000000 --- a/common/aws/cli-copy.txt +++ /dev/null @@ -1,141 +0,0 @@ -package aws - -import ( - "github.com/Layr-Labs/eigenda/common" - "github.com/urfave/cli" - "time" -) - -var ( - RegionFlagName = "aws.region" - AccessKeyIdFlagName = "aws.access-key-id" - SecretAccessKeyFlagName = "aws.secret-access-key" - EndpointURLFlagName = "aws.endpoint-url" - FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars" - FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor" - FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant" - FragmentReadTimeoutFlagName = "aws.fragment-read-timeout" - FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout" -) - -type ClientConfig struct { - // Region is the region to use when interacting with S3. Default is "us-east-2". - Region string - // AccessKey to use when interacting with S3. - AccessKey string - // SecretAccessKey to use when interacting with S3. - SecretAccessKey string - // EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used. - EndpointURL string - - // FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files. - // A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3. - FragmentPrefixChars int - // FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files. - // A non-zero value for this parameter adds a number of workers equal to the number of cores times this value. - // Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the - // workers will be blocked on I/O most of the time. - FragmentParallelismFactor int - // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files. - // A non-zero value for this parameter adds a constant number of workers. Default is 0. - FragmentParallelismConstant int - // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read. - // Default is 30 seconds. - FragmentReadTimeout time.Duration - // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write. - // Default is 30 seconds. 
- FragmentWriteTimeout time.Duration -} - -func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { - return []cli.Flag{ - cli.StringFlag{ - Name: common.PrefixFlag(flagPrefix, RegionFlagName), - Usage: "AWS Region", - Required: true, - EnvVar: common.PrefixEnvVar(envPrefix, "AWS_REGION"), - }, - cli.StringFlag{ - Name: common.PrefixFlag(flagPrefix, AccessKeyIdFlagName), - Usage: "AWS Access Key Id", - Required: false, - Value: "", - EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ACCESS_KEY_ID"), - }, - cli.StringFlag{ - Name: common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName), - Usage: "AWS Secret Access Key", - Required: false, - Value: "", - EnvVar: common.PrefixEnvVar(envPrefix, "AWS_SECRET_ACCESS_KEY"), - }, - cli.StringFlag{ - Name: common.PrefixFlag(flagPrefix, EndpointURLFlagName), - Usage: "AWS Endpoint URL", - Required: false, - Value: "", - EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "The number of characters of the key to use as the prefix for fragmented files", - Required: false, - Value: 3, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "Add this many threads times the number of cores to the worker pool", - Required: false, - Value: 8, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), - Usage: "Add this many threads to the worker pool", - Required: false, - Value: 0, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented read", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented write", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), - }, - } -} - -func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { - return ClientConfig{ - Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), - AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), - SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), - EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), - FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), - FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), - FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), - FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), - FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), - } -} - -// DefaultClientConfig returns a new ClientConfig with default values. 
-func DefaultClientConfig() *ClientConfig { - return &ClientConfig{ - Region: "us-east-2", - FragmentPrefixChars: 3, - FragmentParallelismFactor: 8, - FragmentParallelismConstant: 0, - FragmentReadTimeout: 30 * time.Second, - FragmentWriteTimeout: 30 * time.Second, - } -} diff --git a/common/aws/cli.go b/common/aws/cli.go index e10d29f083..e88618d454 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -116,15 +116,15 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { return ClientConfig{ - Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), - AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), - SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), - EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), - //FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), - //FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), - //FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), - //FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), - //FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), + Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), + AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), + SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), + EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), + FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), + FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), + FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), + FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), + FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), } } diff --git a/common/aws/s3/client-copy.txt b/common/aws/s3/client-copy.txt deleted file mode 100644 index 8a88318117..0000000000 --- a/common/aws/s3/client-copy.txt +++ /dev/null @@ -1,336 +0,0 @@ -package s3 - -import ( - "bytes" - "context" - "errors" - "github.com/gammazero/workerpool" - "runtime" - "sync" - - commonaws "github.com/Layr-Labs/eigenda/common/aws" - "github.com/Layr-Labs/eigensdk-go/logging" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/feature/s3/manager" - "github.com/aws/aws-sdk-go-v2/service/s3" -) - -var ( - once sync.Once - ref *client - ErrObjectNotFound = errors.New("object not found") -) - -type Object struct { - Key string - Size int64 -} - -type client struct { - cfg *commonaws.ClientConfig - s3Client *s3.Client - pool *workerpool.WorkerPool - logger logging.Logger -} - -var _ Client = (*client)(nil) - -func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { - var err error - once.Do(func() { - customResolver := aws.EndpointResolverWithOptionsFunc( - func(service, region string, options ...interface{}) 
(aws.Endpoint, error) { - if cfg.EndpointURL != "" { - return aws.Endpoint{ - PartitionID: "aws", - URL: cfg.EndpointURL, - SigningRegion: cfg.Region, - }, nil - } - - // returning EndpointNotFoundError will allow the service to fallback to its default resolution - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - }) - - options := [](func(*config.LoadOptions) error){ - config.WithRegion(cfg.Region), - config.WithEndpointResolverWithOptions(customResolver), - config.WithRetryMode(aws.RetryModeStandard), - } - // If access key and secret access key are not provided, use the default credential provider - if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { - options = append(options, - config.WithCredentialsProvider( - credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) - } - awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) - - if errCfg != nil { - err = errCfg - return - } - - s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { - o.UsePathStyle = true - }) - - workers := 0 - if cfg.FragmentParallelismConstant > 0 { - workers = cfg.FragmentParallelismConstant - } - if cfg.FragmentParallelismFactor > 0 { - workers = cfg.FragmentParallelismFactor * runtime.NumCPU() - } - - if workers == 0 { - workers = 1 - } - pool := workerpool.New(workers) - - ref = &client{ - cfg: &cfg, - s3Client: s3Client, - pool: pool, - logger: logger.With("component", "S3Client"), - } - }) - return ref, err -} - -func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { - var partMiBs int64 = 10 - downloader := manager.NewDownloader(s.s3Client, func(d *manager.Downloader) { - d.PartSize = partMiBs * 1024 * 1024 // 10MB per part - d.Concurrency = 3 //The number of goroutines to spin up in parallel per call to Upload when sending parts - }) - - buffer := manager.NewWriteAtBuffer([]byte{}) - _, err := downloader.Download(ctx, buffer, &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if err != nil { - return nil, err - } - - if buffer == nil || len(buffer.Bytes()) == 0 { - return nil, ErrObjectNotFound - } - - return buffer.Bytes(), nil -} - -func (s *client) UploadObject(ctx context.Context, bucket string, key string, data []byte) error { - var partMiBs int64 = 10 - uploader := manager.NewUploader(s.s3Client, func(u *manager.Uploader) { - u.PartSize = partMiBs * 1024 * 1024 // 10MiB per part - u.Concurrency = 3 //The number of goroutines to spin up in parallel per call to upload when sending parts - }) - - _, err := uploader.Upload(ctx, &s3.PutObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - Body: bytes.NewReader(data), - }) - if err != nil { - return err - } - - return nil -} - -func (s *client) DeleteObject(ctx context.Context, bucket string, key string) error { - _, err := s.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if err != nil { - return err - } - - return err -} - -func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) ([]Object, error) { - output, err := s.s3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ - Bucket: aws.String(bucket), - Prefix: aws.String(prefix), - }) - if err != nil { - return nil, err - } - - objects := make([]Object, 0, len(output.Contents)) - for _, object := range output.Contents { - var size int64 = 0 - if object.Size != nil { - size = *object.Size - } - objects = append(objects, Object{ - Key: *object.Key, - Size: size, - }) 
- } - return objects, nil -} - -func (s *client) CreateBucket(ctx context.Context, bucket string) error { - _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ - Bucket: aws.String(bucket), - }) - if err != nil { - return err - } - - return nil -} - -func (s *client) FragmentedUploadObject( - ctx context.Context, - bucket string, - key string, - data []byte, - fragmentSize int) error { - - fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) - if err != nil { - return err - } - resultChannel := make(chan error, len(fragments)) - - ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) - defer cancel() - - for _, fragment := range fragments { - fragmentCapture := fragment - s.pool.Submit(func() { - s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) - }) - } - - for range fragments { - err := <-resultChannel - if err != nil { - return err - } - } - return ctx.Err() - -} - -// fragmentedWriteTask writes a single file to S3. -func (s *client) fragmentedWriteTask( - ctx context.Context, - resultChannel chan error, - fragment *Fragment, - bucket string) { - - _, err := s.s3Client.PutObject(ctx, - &s3.PutObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(fragment.FragmentKey), - Body: bytes.NewReader(fragment.Data), - }) - - resultChannel <- err -} - -func (s *client) FragmentedDownloadObject( - ctx context.Context, - bucket string, - key string, - fileSize int, - fragmentSize int) ([]byte, error) { - - if fragmentSize <= 0 { - return nil, errors.New("fragmentSize must be greater than 0") - } - - fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize)) - if err != nil { - return nil, err - } - resultChannel := make(chan *readResult, len(fragmentKeys)) - - ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) - defer cancel() - - for i, fragmentKey := range fragmentKeys { - boundFragmentKey := fragmentKey - boundI := i - s.pool.Submit(func() { - s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI) - }) - } - - fragments := make([]*Fragment, len(fragmentKeys)) - for i := 0; i < len(fragmentKeys); i++ { - result := <-resultChannel - if result.err != nil { - return nil, result.err - } - fragments[result.fragment.Index] = result.fragment - } - - if ctx.Err() != nil { - return nil, ctx.Err() - } - - return RecombineFragments(fragments) - -} - -// readResult is the result of a read task. -type readResult struct { - fragment *Fragment - err error -} - -// readTask reads a single file from S3. 
-func (s *client) readTask( - ctx context.Context, - resultChannel chan *readResult, - bucket string, - key string, - index int) { - - result := &readResult{} - defer func() { - resultChannel <- result - }() - - ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - - if err != nil { - result.err = err - return - } - - data := make([]byte, *ret.ContentLength) - bytesRead := 0 - - for bytesRead < len(data) && ctx.Err() == nil { - count, err := ret.Body.Read(data[bytesRead:]) - if err != nil && err.Error() != "EOF" { - result.err = err - return - } - bytesRead += count - } - - result.fragment = &Fragment{ - FragmentKey: key, - Data: data, - Index: index, - } - - err = ret.Body.Close() - if err != nil { - result.err = err - } -} diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go index 231d546ae6..8a88318117 100644 --- a/common/aws/s3/client.go +++ b/common/aws/s3/client.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "errors" + "github.com/gammazero/workerpool" + "runtime" "sync" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -27,7 +29,9 @@ type Object struct { } type client struct { + cfg *commonaws.ClientConfig s3Client *s3.Client + pool *workerpool.WorkerPool logger logging.Logger } @@ -36,18 +40,19 @@ var _ Client = (*client)(nil) func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { var err error once.Do(func() { - customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { - if cfg.EndpointURL != "" { - return aws.Endpoint{ - PartitionID: "aws", - URL: cfg.EndpointURL, - SigningRegion: cfg.Region, - }, nil - } - - // returning EndpointNotFoundError will allow the service to fallback to its default resolution - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - }) + customResolver := aws.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if cfg.EndpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: cfg.EndpointURL, + SigningRegion: cfg.Region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) options := [](func(*config.LoadOptions) error){ config.WithRegion(cfg.Region), @@ -56,7 +61,9 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L } // If access key and secret access key are not provided, use the default credential provider if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { - options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) + options = append(options, + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) } awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) 
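The hunks below restore the worker pool and the fragmented read/write path. As a reminder of how that API is driven end to end, a rough round-trip sketch mirroring the disabled RandomOperationsTest; the bucket, key, and sizes are illustrative:

    // Upload: a 2500-byte payload with fragmentSize 1000 is stored as
    // three fragment objects and reassembled on download.
    data := make([]byte, 2500)
    fragmentSize := 1000
    err := client.FragmentedUploadObject(ctx, "eigen-test", "blob-key", data, fragmentSize)
    if err != nil {
        return err
    }

    // fileSize and fragmentSize must match the values used at upload time.
    readBack, err := client.FragmentedDownloadObject(ctx, "eigen-test", "blob-key", len(data), fragmentSize)
    if err != nil {
        return err
    }
    _ = readBack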
@@ -64,23 +71,32 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L err = errCfg return } + s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { o.UsePathStyle = true }) - ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")} - }) - return ref, err -} -func (s *client) CreateBucket(ctx context.Context, bucket string) error { - _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ - Bucket: aws.String(bucket), - }) - if err != nil { - return err - } + workers := 0 + if cfg.FragmentParallelismConstant > 0 { + workers = cfg.FragmentParallelismConstant + } + if cfg.FragmentParallelismFactor > 0 { + workers = cfg.FragmentParallelismFactor * runtime.NumCPU() + } - return nil + if workers == 0 { + workers = 1 + } + pool := workerpool.New(workers) + + ref = &client{ + cfg: &cfg, + s3Client: s3Client, + pool: pool, + logger: logger.With("component", "S3Client"), + } + }) + return ref, err } func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { @@ -159,3 +175,162 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) } return objects, nil } + +func (s *client) CreateBucket(ctx context.Context, bucket string) error { + _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + return err + } + + return nil +} + +func (s *client) FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error { + + fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) + if err != nil { + return err + } + resultChannel := make(chan error, len(fragments)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for _, fragment := range fragments { + fragmentCapture := fragment + s.pool.Submit(func() { + s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) + }) + } + + for range fragments { + err := <-resultChannel + if err != nil { + return err + } + } + return ctx.Err() + +} + +// fragmentedWriteTask writes a single file to S3. 
+func (s *client) fragmentedWriteTask( + ctx context.Context, + resultChannel chan error, + fragment *Fragment, + bucket string) { + + _, err := s.s3Client.PutObject(ctx, + &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fragment.FragmentKey), + Body: bytes.NewReader(fragment.Data), + }) + + resultChannel <- err +} + +func (s *client) FragmentedDownloadObject( + ctx context.Context, + bucket string, + key string, + fileSize int, + fragmentSize int) ([]byte, error) { + + if fragmentSize <= 0 { + return nil, errors.New("fragmentSize must be greater than 0") + } + + fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize)) + if err != nil { + return nil, err + } + resultChannel := make(chan *readResult, len(fragmentKeys)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for i, fragmentKey := range fragmentKeys { + boundFragmentKey := fragmentKey + boundI := i + s.pool.Submit(func() { + s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI) + }) + } + + fragments := make([]*Fragment, len(fragmentKeys)) + for i := 0; i < len(fragmentKeys); i++ { + result := <-resultChannel + if result.err != nil { + return nil, result.err + } + fragments[result.fragment.Index] = result.fragment + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + return RecombineFragments(fragments) + +} + +// readResult is the result of a read task. +type readResult struct { + fragment *Fragment + err error +} + +// readTask reads a single file from S3. +func (s *client) readTask( + ctx context.Context, + resultChannel chan *readResult, + bucket string, + key string, + index int) { + + result := &readResult{} + defer func() { + resultChannel <- result + }() + + ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + + if err != nil { + result.err = err + return + } + + data := make([]byte, *ret.ContentLength) + bytesRead := 0 + + for bytesRead < len(data) && ctx.Err() == nil { + count, err := ret.Body.Read(data[bytesRead:]) + if err != nil && err.Error() != "EOF" { + result.err = err + return + } + bytesRead += count + } + + result.fragment = &Fragment{ + FragmentKey: key, + Data: data, + Index: index, + } + + err = ret.Body.Close() + if err != nil { + result.err = err + } +} diff --git a/common/aws/s3/s3.go b/common/aws/s3/s3.go index a470d40c3f..74089099a9 100644 --- a/common/aws/s3/s3.go +++ b/common/aws/s3/s3.go @@ -21,28 +21,28 @@ type Client interface { // CreateBucket creates a bucket in S3. CreateBucket(ctx context.Context, bucket string) error - //// FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each - //// file uploaded to S3. If the file is larger than fragmentSize then it will be broken into - //// smaller parts and uploaded in parallel. The file will be reassembled on download. - //// - //// Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to - //// download the file. It is not advised to use DeleteObject on files uploaded with this method (if such - //// functionality is required, a new method to do so should be added to this interface). - //FragmentedUploadObject( - // ctx context.Context, - // bucket string, - // key string, - // data []byte, - // fragmentSize int) error + // FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each + // file uploaded to S3. 
If the file is larger than fragmentSize then it will be broken into + // smaller parts and uploaded in parallel. The file will be reassembled on download. // - //// FragmentedDownloadObject downloads a file from S3, as written by Upload. The fileSize (in bytes) and fragmentSize - //// must be the same as the values used in the FragmentedUploadObject call. - //// - //// Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method. - //FragmentedDownloadObject( - // ctx context.Context, - // bucket string, - // key string, - // fileSize int, - // fragmentSize int) ([]byte, error) + // Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to + // download the file. It is not advised to use DeleteObject on files uploaded with this method (if such + // functionality is required, a new method to do so should be added to this interface). + FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error + + // FragmentedDownloadObject downloads a file from S3, as written by Upload. The fileSize (in bytes) and fragmentSize + // must be the same as the values used in the FragmentedUploadObject call. + // + // Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method. + FragmentedDownloadObject( + ctx context.Context, + bucket string, + key string, + fileSize int, + fragmentSize int) ([]byte, error) } diff --git a/common/aws/test/client_test.go b/common/aws/test/client_test.go index dfd935d973..0f5bc4087a 100644 --- a/common/aws/test/client_test.go +++ b/common/aws/test/client_test.go @@ -1,176 +1,176 @@ package test -//import ( -// "context" -// "github.com/Layr-Labs/eigenda/common" -// "github.com/Layr-Labs/eigenda/common/aws" -// "github.com/Layr-Labs/eigenda/common/aws/s3" -// "github.com/Layr-Labs/eigenda/common/mock" -// tu "github.com/Layr-Labs/eigenda/common/testutils" -// "github.com/Layr-Labs/eigenda/inabox/deploy" -// "github.com/ory/dockertest/v3" -// "github.com/stretchr/testify/assert" -// "math/rand" -// "os" -// "testing" -//) -// -//var ( -// dockertestPool *dockertest.Pool -// dockertestResource *dockertest.Resource -//) -// -//const ( -// localstackPort = "4570" -// localstackHost = "http://0.0.0.0:4570" -// bucket = "eigen-test" -//) -// -//type clientBuilder struct { -// // This method is called at the beginning of the test. -// start func() error -// // This method is called to build a new client. -// build func() (s3.Client, error) -// // This method is called at the end of the test when all operations are done. 
-// finish func() error -//} -// -//var clientBuilders = []*clientBuilder{ -// { -// start: func() error { -// return nil -// }, -// build: func() (s3.Client, error) { -// return mock.NewS3Client(), nil -// }, -// finish: func() error { -// return nil -// }, -// }, -// { -// start: func() error { -// return setupLocalstack() -// }, -// build: func() (s3.Client, error) { -// -// logger, err := common.NewLogger(common.DefaultLoggerConfig()) -// if err != nil { -// return nil, err -// } -// -// config := aws.DefaultClientConfig() -// config.EndpointURL = localstackHost -// config.Region = "us-east-1" -// -// err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack") -// if err != nil { -// return nil, err -// } -// err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack") -// if err != nil { -// return nil, err -// } -// -// client, err := s3.NewClient(context.Background(), *config, logger) -// if err != nil { -// return nil, err -// } -// -// err = client.CreateBucket(context.Background(), bucket) -// if err != nil { -// return nil, err -// } -// -// return client, nil -// }, -// finish: func() error { -// teardownLocalstack() -// return nil -// }, -// }, -//} -// -//func setupLocalstack() error { -// deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") -// -// if deployLocalStack { -// var err error -// dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort) -// if err != nil && err.Error() == "container already exists" { -// teardownLocalstack() -// return err -// } -// } -// return nil -//} -// -//func teardownLocalstack() { -// deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") -// -// if deployLocalStack { -// deploy.PurgeDockertestResources(dockertestPool, dockertestResource) -// } -//} -// -//func RandomOperationsTest(t *testing.T, client s3.Client) { -// numberToWrite := 100 -// expectedData := make(map[string][]byte) -// -// fragmentSize := rand.Intn(1000) + 1000 -// -// for i := 0; i < numberToWrite; i++ { -// key := tu.RandomString(10) -// fragmentMultiple := rand.Float64() * 10 -// dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1 -// data := tu.RandomBytes(dataSize) -// expectedData[key] = data -// -// err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize) -// assert.NoError(t, err) -// } -// -// // Read back the data -// for key, expected := range expectedData { -// data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize) -// assert.NoError(t, err) -// assert.Equal(t, expected, data) -// } -//} -// -//func TestRandomOperations(t *testing.T) { -// tu.InitializeRandom() -// for _, builder := range clientBuilders { -// err := builder.start() -// assert.NoError(t, err) -// -// client, err := builder.build() -// assert.NoError(t, err) -// RandomOperationsTest(t, client) -// -// err = builder.finish() -// assert.NoError(t, err) -// } -//} -// -//func ReadNonExistentValueTest(t *testing.T, client s3.Client) { -// _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) -// assert.Error(t, err) -// randomKey := tu.RandomString(10) -// _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) -// assert.Error(t, err) -//} -// -//func TestReadNonExistentValue(t *testing.T) { -// tu.InitializeRandom() -// for _, builder := range clientBuilders { -// err := builder.start() -// assert.NoError(t, err) -// -// client, err := builder.build() -// assert.NoError(t, 
err)
-//		ReadNonExistentValueTest(t, client)
-//
-//		err = builder.finish()
-//		assert.NoError(t, err)
-//	}
-//}
+import (
+	"context"
+	"github.com/Layr-Labs/eigenda/common"
+	"github.com/Layr-Labs/eigenda/common/aws"
+	"github.com/Layr-Labs/eigenda/common/aws/s3"
+	"github.com/Layr-Labs/eigenda/common/mock"
+	tu "github.com/Layr-Labs/eigenda/common/testutils"
+	"github.com/Layr-Labs/eigenda/inabox/deploy"
+	"github.com/ory/dockertest/v3"
+	"github.com/stretchr/testify/assert"
+	"math/rand"
+	"os"
+	"testing"
+)
+
+var (
+	dockertestPool     *dockertest.Pool
+	dockertestResource *dockertest.Resource
+)
+
+const (
+	localstackPort = "4570"
+	localstackHost = "http://0.0.0.0:4570"
+	bucket         = "eigen-test"
+)
+
+type clientBuilder struct {
+	// This method is called at the beginning of the test.
+	start func() error
+	// This method is called to build a new client.
+	build func() (s3.Client, error)
+	// This method is called at the end of the test when all operations are done.
+	finish func() error
+}
+
+var clientBuilders = []*clientBuilder{
+	{
+		start: func() error {
+			return nil
+		},
+		build: func() (s3.Client, error) {
+			return mock.NewS3Client(), nil
+		},
+		finish: func() error {
+			return nil
+		},
+	},
+	{
+		start: func() error {
+			return setupLocalstack()
+		},
+		build: func() (s3.Client, error) {
+
+			logger, err := common.NewLogger(common.DefaultLoggerConfig())
+			if err != nil {
+				return nil, err
+			}
+
+			config := aws.DefaultClientConfig()
+			config.EndpointURL = localstackHost
+			config.Region = "us-east-1"
+
+			err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack")
+			if err != nil {
+				return nil, err
+			}
+			err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack")
+			if err != nil {
+				return nil, err
+			}
+
+			client, err := s3.NewClient(context.Background(), *config, logger)
+			if err != nil {
+				return nil, err
+			}
+
+			err = client.CreateBucket(context.Background(), bucket)
+			if err != nil {
+				return nil, err
+			}
+
+			return client, nil
+		},
+		finish: func() error {
+			teardownLocalstack()
+			return nil
+		},
+	},
+}
+
+func setupLocalstack() error {
+	deployLocalStack := os.Getenv("DEPLOY_LOCALSTACK") != "false"
+
+	if deployLocalStack {
+		var err error
+		dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort)
+		if err != nil {
+			// Purge a half-started container so a retry begins from a known state,
+			// and always propagate the failure to the caller.
+			if err.Error() == "container already exists" {
+				teardownLocalstack()
+			}
+			return err
+		}
+	}
+	return nil
+}
+
+func teardownLocalstack() {
+	deployLocalStack := os.Getenv("DEPLOY_LOCALSTACK") != "false"
+
+	if deployLocalStack {
+		deploy.PurgeDockertestResources(dockertestPool, dockertestResource)
+	}
+}
+
+func RandomOperationsTest(t *testing.T, client s3.Client) {
+	numberToWrite := 100
+	expectedData := make(map[string][]byte)
+
+	fragmentSize := rand.Intn(1000) + 1000
+
+	for i := 0; i < numberToWrite; i++ {
+		key := tu.RandomString(10)
+		fragmentMultiple := rand.Float64() * 10
+		dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1
+		data := tu.RandomBytes(dataSize)
+		expectedData[key] = data
+
+		err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize)
+		assert.NoError(t, err)
+	}
+
+	// Read back the data
+	for key, expected := range expectedData {
+		data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize)
+		assert.NoError(t, err)
+		assert.Equal(t, expected, data)
+	}
+}
+
+func TestRandomOperations(t *testing.T) {
+	tu.InitializeRandom()
+	for _, builder := range clientBuilders {
+		err := builder.start()
+		assert.NoError(t, err)
+
+ client, err := builder.build() + assert.NoError(t, err) + RandomOperationsTest(t, client) + + err = builder.finish() + assert.NoError(t, err) + } +} + +func ReadNonExistentValueTest(t *testing.T, client s3.Client) { + _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) + assert.Error(t, err) + randomKey := tu.RandomString(10) + _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) + assert.Error(t, err) +} + +func TestReadNonExistentValue(t *testing.T) { + tu.InitializeRandom() + for _, builder := range clientBuilders { + err := builder.start() + assert.NoError(t, err) + + client, err := builder.build() + assert.NoError(t, err) + ReadNonExistentValueTest(t, client) + + err = builder.finish() + assert.NoError(t, err) + } +} From 9e4b3c186ebff68524f6be77ae06251d1e0abd72 Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Thu, 31 Oct 2024 13:03:08 -0500 Subject: [PATCH 10/11] Made suggested changes. Signed-off-by: Cody Littley --- common/aws/s3/client.go | 6 +-- common/aws/s3/fragment.go | 26 +++++----- common/aws/{test => s3}/fragment_test.go | 65 +++++++++++++----------- 3 files changed, 50 insertions(+), 47 deletions(-) rename common/aws/{test => s3}/fragment_test.go (77%) diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go index 8a88318117..10b32272c2 100644 --- a/common/aws/s3/client.go +++ b/common/aws/s3/client.go @@ -194,7 +194,7 @@ func (s *client) FragmentedUploadObject( data []byte, fragmentSize int) error { - fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) + fragments, err := breakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) if err != nil { return err } @@ -248,7 +248,7 @@ func (s *client) FragmentedDownloadObject( return nil, errors.New("fragmentSize must be greater than 0") } - fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize)) + fragmentKeys, err := getFragmentKeys(key, s.cfg.FragmentPrefixChars, getFragmentCount(fileSize, fragmentSize)) if err != nil { return nil, err } @@ -278,7 +278,7 @@ func (s *client) FragmentedDownloadObject( return nil, ctx.Err() } - return RecombineFragments(fragments) + return recombineFragments(fragments) } diff --git a/common/aws/s3/fragment.go b/common/aws/s3/fragment.go index 6f978fbdc6..21da697d96 100644 --- a/common/aws/s3/fragment.go +++ b/common/aws/s3/fragment.go @@ -6,8 +6,8 @@ import ( "strings" ) -// GetFragmentCount returns the number of fragments that a file of the given size will be broken into. -func GetFragmentCount(fileSize int, fragmentSize int) int { +// getFragmentCount returns the number of fragments that a file of the given size will be broken into. +func getFragmentCount(fileSize int, fragmentSize int) int { if fileSize < fragmentSize { return 1 } else if fileSize%fragmentSize == 0 { @@ -17,7 +17,7 @@ func GetFragmentCount(fileSize int, fragmentSize int) int { } } -// GetFragmentKey returns the key for the fragment at the given index. +// getFragmentKey returns the key for the fragment at the given index. // // Fragment keys take the form of "prefix/body-index[f]". The prefix is the first prefixLength characters // of the file key. The body is the file key. The index is the index of the fragment. 
The character "f" is appended
@@ -25,7 +25,7 @@ func GetFragmentCount(fileSize int, fragmentSize int) int {
 //
 // Example: fileKey="abc123", prefixLength=2, fragmentCount=3
 // The keys will be "ab/abc123-0", "ab/abc123-1", "ab/abc123-2f"
-func GetFragmentKey(fileKey string, prefixLength int, fragmentCount int, index int) (string, error) {
+func getFragmentKey(fileKey string, prefixLength int, fragmentCount int, index int) (string, error) {
 	var prefix string
 	if prefixLength > len(fileKey) {
 		prefix = fileKey
@@ -52,9 +52,9 @@ type Fragment struct {
 	Index int
 }
 
-// BreakIntoFragments breaks a file into fragments of the given size.
-func BreakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentSize int) ([]*Fragment, error) {
-	fragmentCount := GetFragmentCount(len(data), fragmentSize)
+// breakIntoFragments breaks a file into fragments of the given size.
+func breakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentSize int) ([]*Fragment, error) {
+	fragmentCount := getFragmentCount(len(data), fragmentSize)
 	fragments := make([]*Fragment, fragmentCount)
 	for i := 0; i < fragmentCount; i++ {
 		start := i * fragmentSize
@@ -63,7 +63,7 @@ func BreakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentS
 			end = len(data)
 		}
 
-		fragmentKey, err := GetFragmentKey(fileKey, prefixLength, fragmentCount, i)
+		fragmentKey, err := getFragmentKey(fileKey, prefixLength, fragmentCount, i)
 		if err != nil {
 			return nil, err
 		}
@@ -76,11 +76,11 @@ func BreakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentS
 	return fragments, nil
 }
 
-// GetFragmentKeys returns the keys for all fragments of a file.
-func GetFragmentKeys(fileKey string, prefixLength int, fragmentCount int) ([]string, error) {
+// getFragmentKeys returns the keys for all fragments of a file.
+func getFragmentKeys(fileKey string, prefixLength int, fragmentCount int) ([]string, error) {
 	keys := make([]string, fragmentCount)
 	for i := 0; i < fragmentCount; i++ {
-		fragmentKey, err := GetFragmentKey(fileKey, prefixLength, fragmentCount, i)
+		fragmentKey, err := getFragmentKey(fileKey, prefixLength, fragmentCount, i)
 		if err != nil {
 			return nil, err
 		}
@@ -89,9 +89,9 @@ func GetFragmentKeys(fileKey string, prefixLength int, fragmentCount int) ([]str
 	return keys, nil
 }
 
-// RecombineFragments recombines fragments into a single file.
+// recombineFragments recombines fragments into a single file.
 // Returns an error if any fragments are missing.
-func RecombineFragments(fragments []*Fragment) ([]byte, error) {
+func recombineFragments(fragments []*Fragment) ([]byte, error) {
 	if len(fragments) == 0 {
 		return nil, fmt.Errorf("no fragments")
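The "prefix/body-index[f]" naming rule documented above is simple enough to reproduce independently. Here is a minimal sketch, where fragmentName is a hypothetical helper mirroring what getFragmentKey computes (minus its error handling), included only to make the scheme concrete:

package main

import "fmt"

// fragmentName mirrors the documented "prefix/body-index[f]" scheme: the prefix
// is the first prefixLength characters of the file key, and the final
// fragment's index is suffixed with "f" so a reader can recognize the last
// fragment without knowing the count in advance.
func fragmentName(fileKey string, prefixLength int, fragmentCount int, index int) string {
	prefix := fileKey
	if prefixLength < len(fileKey) {
		prefix = fileKey[:prefixLength]
	}
	postfix := ""
	if index == fragmentCount-1 {
		postfix = "f"
	}
	return fmt.Sprintf("%s/%s-%d%s", prefix, fileKey, index, postfix)
}

func main() {
	// Reproduces the example from the godoc: fileKey="abc123", prefixLength=2, fragmentCount=3.
	for i := 0; i < 3; i++ {
		fmt.Println(fragmentName("abc123", 2, 3, i))
	}
	// Output:
	// ab/abc123-0
	// ab/abc123-1
	// ab/abc123-2f
}

The prefix exists to spread fragments of different files across S3 key space, and the "f" marker lets a download detect a missing final fragment, which is exactly what TestMissingFinalFragment below exercises.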
diff --git a/common/aws/test/fragment_test.go b/common/aws/s3/fragment_test.go
similarity index 77%
rename from common/aws/test/fragment_test.go
rename to common/aws/s3/fragment_test.go
index fc5a257731..04271ce8e8 100644
--- a/common/aws/test/fragment_test.go
+++ b/common/aws/s3/fragment_test.go
@@ -1,8 +1,7 @@
-package test
+package s3
 
 import (
 	"fmt"
-	"github.com/Layr-Labs/eigenda/common/aws/s3"
 	tu "github.com/Layr-Labs/eigenda/common/testutils"
 	"github.com/stretchr/testify/assert"
 	"math/rand"
@@ -16,40 +15,40 @@ func TestGetFragmentCount(t *testing.T) {
 	// Test a file smaller than a fragment
 	fileSize := rand.Intn(100) + 100
 	fragmentSize := fileSize * 2
-	fragmentCount := s3.GetFragmentCount(fileSize, fragmentSize)
+	fragmentCount := getFragmentCount(fileSize, fragmentSize)
 	assert.Equal(t, 1, fragmentCount)
 
 	// Test a file that can fit in a single fragment
 	fileSize = rand.Intn(100) + 100
 	fragmentSize = fileSize
-	fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize)
+	fragmentCount = getFragmentCount(fileSize, fragmentSize)
 	assert.Equal(t, 1, fragmentCount)
 
 	// Test a file that is one byte larger than a fragment
 	fileSize = rand.Intn(100) + 100
 	fragmentSize = fileSize - 1
-	fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize)
+	fragmentCount = getFragmentCount(fileSize, fragmentSize)
 	assert.Equal(t, 2, fragmentCount)
 
 	// Test a file that is one less than a multiple of the fragment size
 	fragmentSize = rand.Intn(100) + 100
 	expectedFragmentCount := rand.Intn(10) + 1
 	fileSize = fragmentSize*expectedFragmentCount - 1
-	fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize)
+	fragmentCount = getFragmentCount(fileSize, fragmentSize)
 	assert.Equal(t, expectedFragmentCount, fragmentCount)
 
 	// Test a file that is a multiple of the fragment size
 	fragmentSize = rand.Intn(100) + 100
 	expectedFragmentCount = rand.Intn(10) + 1
 	fileSize = fragmentSize * expectedFragmentCount
-	fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize)
+	fragmentCount = getFragmentCount(fileSize, fragmentSize)
 	assert.Equal(t, expectedFragmentCount, fragmentCount)
 
 	// Test a file that is one more than a multiple of the fragment size
 	fragmentSize = rand.Intn(100) + 100
 	expectedFragmentCount = rand.Intn(10) + 2
 	fileSize = fragmentSize*(expectedFragmentCount-1) + 1
-	fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize)
+	fragmentCount = getFragmentCount(fileSize, fragmentSize)
 	assert.Equal(t, expectedFragmentCount, fragmentCount)
 }
 
@@ -63,7 +62,7 @@ func TestPrefix(t *testing.T) {
 	for i := 0; i < keyLength*2; i++ {
 		fragmentCount := rand.Intn(10) + 10
 		fragmentIndex := rand.Intn(fragmentCount)
-		fragmentKey, err := s3.GetFragmentKey(key, i, fragmentCount, fragmentIndex)
+		fragmentKey, err := getFragmentKey(key, i, fragmentCount, fragmentIndex)
 		assert.NoError(t, err)
 
 		parts := strings.Split(fragmentKey, "/")
@@ -87,7 +86,7 @@ func TestKeyBody(t *testing.T) {
 		key := tu.RandomString(keyLength)
 		fragmentCount := rand.Intn(10) + 10
 		fragmentIndex := rand.Intn(fragmentCount)
-		fragmentKey, err := s3.GetFragmentKey(key, rand.Intn(10), fragmentCount, fragmentIndex)
+		fragmentKey, err := getFragmentKey(key, rand.Intn(10), fragmentCount, fragmentIndex)
 		assert.NoError(t, err)
 
 		parts := strings.Split(fragmentKey, "/")
@@ -107,7 +106,7 @@ func TestKeyIndex(t *testing.T) {
 	for i := 0; i < 10; i++ {
 		fragmentCount := rand.Intn(10) + 10
index := rand.Intn(fragmentCount) - fragmentKey, err := s3.GetFragmentKey(tu.RandomString(10), rand.Intn(10), fragmentCount, index) + fragmentKey, err := getFragmentKey(tu.RandomString(10), rand.Intn(10), fragmentCount, index) assert.NoError(t, err) parts := strings.Split(fragmentKey, "/") @@ -127,7 +126,7 @@ func TestKeyPostfix(t *testing.T) { segmentCount := rand.Intn(10) + 10 for i := 0; i < segmentCount; i++ { - fragmentKey, err := s3.GetFragmentKey(tu.RandomString(10), rand.Intn(10), segmentCount, i) + fragmentKey, err := getFragmentKey(tu.RandomString(10), rand.Intn(10), segmentCount, i) assert.NoError(t, err) if i == segmentCount-1 { @@ -138,11 +137,15 @@ func TestKeyPostfix(t *testing.T) { } } +// TestExampleInGodoc tests the example provided in the documentation for getFragmentKey(). +// +// Example: fileKey="abc123", prefixLength=2, fragmentCount=3 +// The keys will be "ab/abc123-0", "ab/abc123-1", "ab/abc123-2f" func TestExampleInGodoc(t *testing.T) { fileKey := "abc123" prefixLength := 2 fragmentCount := 3 - fragmentKeys, err := s3.GetFragmentKeys(fileKey, prefixLength, fragmentCount) + fragmentKeys, err := getFragmentKeys(fileKey, prefixLength, fragmentCount) assert.NoError(t, err) assert.Equal(t, 3, len(fragmentKeys)) assert.Equal(t, "ab/abc123-0", fragmentKeys[0]) @@ -157,12 +160,12 @@ func TestGetFragmentKeys(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentCount := rand.Intn(10) + 10 - fragmentKeys, err := s3.GetFragmentKeys(fileKey, prefixLength, fragmentCount) + fragmentKeys, err := getFragmentKeys(fileKey, prefixLength, fragmentCount) assert.NoError(t, err) assert.Equal(t, fragmentCount, len(fragmentKeys)) for i := 0; i < fragmentCount; i++ { - expectedKey, err := s3.GetFragmentKey(fileKey, prefixLength, fragmentCount, i) + expectedKey, err := getFragmentKey(fileKey, prefixLength, fragmentCount, i) assert.NoError(t, err) assert.Equal(t, expectedKey, fragmentKeys[i]) @@ -192,14 +195,14 @@ func TestGetFragments(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentSize := rand.Intn(100) + 100 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) - assert.Equal(t, s3.GetFragmentCount(len(data), fragmentSize), len(fragments)) + assert.Equal(t, getFragmentCount(len(data), fragmentSize), len(fragments)) totalSize := 0 for i, fragment := range fragments { - fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, len(fragments), i) + fragmentKey, err := getFragmentKey(fileKey, prefixLength, len(fragments), i) assert.NoError(t, err) assert.Equal(t, fragmentKey, fragment.FragmentKey) @@ -224,11 +227,11 @@ func TestGetFragmentsSmallFile(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentSize := rand.Intn(100) + 100 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) assert.Equal(t, 1, len(fragments)) - fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, 1, 0) + fragmentKey, err := getFragmentKey(fileKey, prefixLength, 1, 0) assert.NoError(t, err) assert.Equal(t, fragmentKey, fragments[0].FragmentKey) assert.Equal(t, data, fragments[0].Data) @@ -243,11 +246,11 @@ func TestGetFragmentsExactlyOnePerfectlySizedFile(t *testing.T) { data := tu.RandomBytes(fragmentSize) prefixLength := rand.Intn(3) + 1 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + 
fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) assert.Equal(t, 1, len(fragments)) - fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, 1, 0) + fragmentKey, err := getFragmentKey(fileKey, prefixLength, 1, 0) assert.NoError(t, err) assert.Equal(t, fragmentKey, fragments[0].FragmentKey) assert.Equal(t, data, fragments[0].Data) @@ -262,9 +265,9 @@ func TestRecombineFragments(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentSize := rand.Intn(100) + 100 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) - recombinedData, err := s3.RecombineFragments(fragments) + recombinedData, err := recombineFragments(fragments) assert.NoError(t, err) assert.Equal(t, data, recombinedData) @@ -274,7 +277,7 @@ func TestRecombineFragments(t *testing.T) { fragments[i], fragments[j] = fragments[j], fragments[i] } - recombinedData, err = s3.RecombineFragments(fragments) + recombinedData, err = recombineFragments(fragments) assert.NoError(t, err) assert.Equal(t, data, recombinedData) } @@ -287,10 +290,10 @@ func TestRecombineFragmentsSmallFile(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentSize := rand.Intn(100) + 100 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) assert.Equal(t, 1, len(fragments)) - recombinedData, err := s3.RecombineFragments(fragments) + recombinedData, err := recombineFragments(fragments) assert.NoError(t, err) assert.Equal(t, data, recombinedData) } @@ -303,13 +306,13 @@ func TestMissingFragment(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentSize := rand.Intn(100) + 100 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) fragmentIndexToSkip := rand.Intn(len(fragments)) fragments = append(fragments[:fragmentIndexToSkip], fragments[fragmentIndexToSkip+1:]...) - _, err = s3.RecombineFragments(fragments[:len(fragments)-1]) + _, err = recombineFragments(fragments[:len(fragments)-1]) assert.Error(t, err) } @@ -321,10 +324,10 @@ func TestMissingFinalFragment(t *testing.T) { prefixLength := rand.Intn(3) + 1 fragmentSize := rand.Intn(100) + 100 - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) assert.NoError(t, err) fragments = fragments[:len(fragments)-1] - _, err = s3.RecombineFragments(fragments) + _, err = recombineFragments(fragments) assert.Error(t, err) } From 39a8a8ae65c626c152695d8ae34f25ee9a1dcdc0 Mon Sep 17 00:00:00 2001 From: Cody Littley Date: Fri, 1 Nov 2024 14:51:14 -0500 Subject: [PATCH 11/11] Made suggested changes. 
Signed-off-by: Cody Littley 
---
 common/aws/s3/client.go | 14 +++++++++-----
 common/aws/s3/s3.go     |  4 ++++
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go
index 10b32272c2..ddc3ce4e14 100644
--- a/common/aws/s3/client.go
+++ b/common/aws/s3/client.go
@@ -4,7 +4,7 @@ import (
 	"bytes"
 	"context"
 	"errors"
-	"github.com/gammazero/workerpool"
+	"golang.org/x/sync/errgroup"
 	"runtime"
 	"sync"
 
@@ -31,7 +31,7 @@ type Object struct {
 type client struct {
 	cfg      *commonaws.ClientConfig
 	s3Client *s3.Client
	pool     *workerpool.WorkerPool
+	pool     *errgroup.Group
 	logger   logging.Logger
 }
 
@@ -87,7 +87,9 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L
 	if workers == 0 {
 		workers = 1
 	}
-	pool := workerpool.New(workers)
+
+	pool := &errgroup.Group{}
+	pool.SetLimit(workers)
 
 	ref = &client{
 		cfg:      &cfg,
@@ -205,8 +207,9 @@ func (s *client) FragmentedUploadObject(
 
 	for _, fragment := range fragments {
 		fragmentCapture := fragment
-		s.pool.Submit(func() {
+		s.pool.Go(func() error {
 			s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket)
+			return nil
 		})
 	}
 
@@ -260,8 +263,9 @@ func (s *client) FragmentedDownloadObject(
 	for i, fragmentKey := range fragmentKeys {
 		boundFragmentKey := fragmentKey
 		boundI := i
-		s.pool.Submit(func() {
+		s.pool.Go(func() error {
 			s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI)
+			return nil
 		})
 	}
 
diff --git a/common/aws/s3/s3.go b/common/aws/s3/s3.go
index 74089099a9..d96172dbcc 100644
--- a/common/aws/s3/s3.go
+++ b/common/aws/s3/s3.go
@@ -28,6 +28,10 @@ type Client interface {
 	// Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to
 	// download the file. It is not advised to use DeleteObject on files uploaded with this method (if such
 	// functionality is required, a new method to do so should be added to this interface).
+	//
+	// Note: if this operation fails partway through, some file fragments may have made it to S3 while others have not.
+	// To prevent long-term accumulation of orphaned fragments, it is suggested to use this method in conjunction
+	// with a bucket lifecycle rule that automatically expires objects (i.e. a TTL).
 	FragmentedUploadObject(
 		ctx context.Context,
 		bucket string,
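For reference, the errgroup adopted in this last patch behaves as a bounded worker pool: a zero-value errgroup.Group is valid, SetLimit caps the number of tasks in flight, and Go blocks once that limit is reached. A minimal sketch of the pattern, independent of the client code; the task body is a placeholder standing in for a fragment read or write:

package main

import (
	"fmt"
	"sync/atomic"

	"golang.org/x/sync/errgroup"
)

func main() {
	var pool errgroup.Group
	pool.SetLimit(4) // at most 4 tasks in flight; further Go calls block until a slot frees

	var completed atomic.Int64
	for i := 0; i < 16; i++ {
		pool.Go(func() error {
			completed.Add(1) // placeholder for a fragment upload/download task
			return nil       // as in the patch, real errors flow through channels, not the group
		})
	}

	// The client never calls Wait (results are collected on channels instead),
	// but a standalone pool must drain before exit.
	if err := pool.Wait(); err != nil {
		fmt.Println("task failed:", err)
	}
	fmt.Println("completed:", completed.Load())
}

Because the tasks always return nil and results are reported over channels, the group is used here purely for its concurrency limit; this is also why discarding the context from errgroup.WithContext would have been harmless, and why constructing a plain Group is the simpler form.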