Skip to content

Commit 8226f42

Browse files
authored
Improve batch job submit validation efficiency (#2179)
1 parent 7fc4226 commit 8226f42

File tree

5 files changed

+72
-44
lines changed

5 files changed

+72
-44
lines changed

pkg/enqueuer/enqueuer.go

+6-5
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,10 @@ type ItemList struct {
5656
}
5757

5858
type S3Lister struct {
59-
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
60-
Includes []string `json:"includes"`
61-
Excludes []string `json:"excludes"`
59+
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
60+
Includes []string `json:"includes"`
61+
Excludes []string `json:"excludes"`
62+
MaxResults *int64 `json:"-"` // this is not currently exposed to the user (it's used for validations)
6263
}
6364

6465
type FilePathLister struct {
@@ -246,7 +247,7 @@ func (e *Enqueuer) enqueueS3Paths(s3PathsLister *FilePathLister) (int, error) {
246247
var s3PathList []string
247248
uploader := newSQSBatchUploader(e.envConfig.APIName, e.envConfig.JobID, e.queueURL, e.aws.SQS())
248249

249-
err := s3IteratorFromLister(e.aws, s3PathsLister.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
250+
_, err := s3IteratorFromLister(e.aws, s3PathsLister.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
250251
s3Path := awslib.S3Path(bucket, *s3Obj.Key)
251252

252253
s3PathList = append(s3PathList, s3Path)
@@ -290,7 +291,7 @@ func (e *Enqueuer) enqueueS3FileContents(delimitedFiles *DelimitedFiles) (int, e
290291
uploader := newSQSBatchUploader(e.envConfig.APIName, e.envConfig.JobID, e.queueURL, e.aws.SQS())
291292

292293
bytesBuffer := bytes.NewBuffer([]byte{})
293-
err := s3IteratorFromLister(e.aws, delimitedFiles.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
294+
_, err := s3IteratorFromLister(e.aws, delimitedFiles.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
294295
s3Path := awslib.S3Path(bucket, *s3Obj.Key)
295296
log.Info("enqueuing contents from file", zap.String("path", s3Path))
296297

pkg/enqueuer/helpers.go

+19-8
Original file line number | Diff line number | Diff line change
@@ -64,13 +64,13 @@ func addJSONObjectsToQueue(uploader *sqsBatchUploader, jsonMessageList *jsonBuff
6464
return nil
6565
}
6666

67-
func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(string, *s3.Object) (bool, error)) error {
67+
func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(string, *s3.Object) (bool, error)) (int64, error) {
6868
includeGlobPatterns := make([]glob.Glob, 0, len(s3Lister.Includes))
6969

7070
for _, includePattern := range s3Lister.Includes {
7171
globExpression, err := glob.Compile(includePattern, '/')
7272
if err != nil {
73-
return errors.Wrap(err, "failed to interpret glob pattern", includePattern)
73+
return 0, errors.Wrap(err, "failed to interpret glob pattern", includePattern)
7474
}
7575
includeGlobPatterns = append(includeGlobPatterns, globExpression)
7676
}
@@ -79,20 +79,22 @@ func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(s
7979
for _, excludePattern := range s3Lister.Excludes {
8080
globExpression, err := glob.Compile(excludePattern, '/')
8181
if err != nil {
82-
return errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
82+
return 0, errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
8383
}
8484
excludeGlobPatterns = append(excludeGlobPatterns, globExpression)
8585
}
8686

87+
var numResults int64
88+
8789
for _, s3Path := range s3Lister.S3Paths {
8890
bucket, key, err := awslib.SplitS3Path(s3Path)
8991
if err != nil {
90-
return err
92+
return 0, err
9193
}
9294

9395
awsClientForBucket, err := awslib.NewFromClientS3Path(s3Path, awsClient)
9496
if err != nil {
95-
return err
97+
return 0, err
9698
}
9799

98100
err = awsClientForBucket.S3Iterator(bucket, key, false, nil, nil, func(s3Obj *s3.Object) (bool, error) {
@@ -117,15 +119,24 @@ func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(s
117119
}
118120

119121
if !shouldSkip {
120-
return fn(bucket, s3Obj)
122+
shouldContinue, err := fn(bucket, s3Obj)
123+
numResults++
124+
if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
125+
shouldContinue = false
126+
}
127+
return shouldContinue, err
121128
}
122129

123130
return true, nil
124131
})
125132
if err != nil {
126-
return err
133+
return 0, err
134+
}
135+
136+
if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
137+
return numResults, nil
127138
}
128139
}
129140

130-
return nil
141+
return numResults, nil
131142
}

pkg/operator/resources/job/batchapi/s3_iterator.go

+19-8
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,13 @@ import (
2626
)
2727

2828
// Takes in a function(shouldSkip, bucketName, s3.Object)
29-
func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) (bool, error)) error {
29+
func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) (bool, error)) (int64, error) {
3030
includeGlobPatterns := make([]glob.Glob, 0, len(s3Lister.Includes))
3131

3232
for _, includePattern := range s3Lister.Includes {
3333
globExpression, err := glob.Compile(includePattern, '/')
3434
if err != nil {
35-
return errors.Wrap(err, "failed to interpret glob pattern", includePattern)
35+
return 0, errors.Wrap(err, "failed to interpret glob pattern", includePattern)
3636
}
3737
includeGlobPatterns = append(includeGlobPatterns, globExpression)
3838
}
@@ -41,20 +41,22 @@ func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object)
4141
for _, excludePattern := range s3Lister.Excludes {
4242
globExpression, err := glob.Compile(excludePattern, '/')
4343
if err != nil {
44-
return errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
44+
return 0, errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
4545
}
4646
excludeGlobPatterns = append(excludeGlobPatterns, globExpression)
4747
}
4848

49+
var numResults int64
50+
4951
for _, s3Path := range s3Lister.S3Paths {
5052
bucket, key, err := aws.SplitS3Path(s3Path)
5153
if err != nil {
52-
return err
54+
return 0, err
5355
}
5456

5557
awsClientForBucket, err := aws.NewFromClientS3Path(s3Path, config.AWS)
5658
if err != nil {
57-
return err
59+
return 0, err
5860
}
5961

6062
err = awsClientForBucket.S3Iterator(bucket, key, false, nil, nil, func(s3Obj *s3.Object) (bool, error) {
@@ -79,15 +81,24 @@ func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object)
7981
}
8082

8183
if !shouldSkip {
82-
return fn(bucket, s3Obj)
84+
shouldContinue, err := fn(bucket, s3Obj)
85+
numResults++
86+
if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
87+
shouldContinue = false
88+
}
89+
return shouldContinue, err
8390
}
8491

8592
return true, nil
8693
})
8794
if err != nil {
88-
return err
95+
return 0, err
96+
}
97+
98+
if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
99+
return numResults, nil
89100
}
90101
}
91102

92-
return nil
103+
return numResults, nil
93104
}

pkg/operator/resources/job/batchapi/validations.go

+24-20
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ import (
2424
awslib "github.com/cortexlabs/cortex/pkg/lib/aws"
2525
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
2626
"github.com/cortexlabs/cortex/pkg/lib/errors"
27+
"github.com/cortexlabs/cortex/pkg/lib/pointer"
2728
"github.com/cortexlabs/cortex/pkg/operator/resources/job"
2829
"github.com/cortexlabs/cortex/pkg/operator/schema"
2930
"github.com/gobwas/glob"
@@ -143,26 +144,30 @@ func validateS3Lister(s3Lister *schema.S3Lister) error {
143144
}
144145
}
145146

146-
filesFound := 0
147147
for _, s3Path := range s3Lister.S3Paths {
148148
if !awslib.IsValidS3Path(s3Path) {
149149
return awslib.ErrorInvalidS3Path(s3Path)
150150
}
151+
}
151152

152-
err := s3IteratorFromLister(*s3Lister, func(objPath string, s3Obj *s3.Object) (bool, error) {
153-
filesFound++
154-
return false, nil
155-
})
156-
if err != nil {
157-
return errors.Wrap(err, s3Path)
158-
}
153+
shortCircuitLister := schema.S3Lister{
154+
S3Paths: s3Lister.S3Paths,
155+
Includes: s3Lister.Includes,
156+
Excludes: s3Lister.Excludes,
157+
MaxResults: pointer.Int64(1),
158+
}
159+
numResults, err := s3IteratorFromLister(shortCircuitLister, func(objPath string, s3Obj *s3.Object) (bool, error) {
160+
return false, nil
161+
})
162+
if err != nil {
163+
return err
164+
}
159165

160-
if filesFound > 0 {
161-
return nil
162-
}
166+
if numResults == 0 {
167+
return ErrorNoS3FilesFound()
163168
}
164169

165-
return ErrorNoS3FilesFound()
170+
return nil
166171
}
167172

168173
func listFilesDryRun(s3Lister *schema.S3Lister) ([]string, error) {
@@ -171,15 +176,14 @@ func listFilesDryRun(s3Lister *schema.S3Lister) ([]string, error) {
171176
if !awslib.IsValidS3Path(s3Path) {
172177
return nil, awslib.ErrorInvalidS3Path(s3Path)
173178
}
179+
}
174180

175-
err := s3IteratorFromLister(*s3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
176-
s3Files = append(s3Files, awslib.S3Path(bucket, *s3Obj.Key))
177-
return true, nil
178-
})
179-
180-
if err != nil {
181-
return nil, errors.Wrap(err, s3Path)
182-
}
181+
_, err := s3IteratorFromLister(*s3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
182+
s3Files = append(s3Files, awslib.S3Path(bucket, *s3Obj.Key))
183+
return true, nil
184+
})
185+
if err != nil {
186+
return nil, err
183187
}
184188

185189
if len(s3Files) == 0 {

pkg/operator/schema/job_submission.go

+4-3
Original file line number | Diff line number | Diff line change
@@ -28,9 +28,10 @@ type ItemList struct {
2828
}
2929

3030
type S3Lister struct {
31-
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
32-
Includes []string `json:"includes"`
33-
Excludes []string `json:"excludes"`
31+
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
32+
Includes []string `json:"includes"`
33+
Excludes []string `json:"excludes"`
34+
MaxResults *int64 `json:"-"` // this is not currently exposed to the user (it's used for validations)
3435
}
3536

3637
type FilePathLister struct {

0 commit comments

Comments
 (0)