Skip to content

Commit b4b1319

Browse files
authored
Training dataset resource bug (#86)
* Default spark compute resource to spark workloads
* Add spark compute to model config and remove default
* Update model config documentation
* Refactoring variable and deleting unnecessary function
* Remove duplicate tag field in doc
* Declare spark compute validation object const as ptr
* Update models.md
1 parent 9eaac2d commit b4b1319

File tree

7 files changed

+89
-72
lines changed

7 files changed

+89
-72
lines changed

docs/applications/resources/models.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,21 @@ Train custom TensorFlow models at scale.
4141
start_delay_secs: <int> # start evaluating after waiting for this many seconds (default: 120)
4242
throttle_secs: <int> # do not re-evaluate unless the last evaluation was started at least this many seconds ago (default: 600)
4343

44-
compute:
44+
compute: # Resources for training and evaluation steps (TensorFlow)
4545
cpu: <string> # CPU request (default: Null)
4646
mem: <string> # memory request (default: Null)
4747
gpu: <string> # GPU request (default: Null)
4848

49+
dataset_compute: # Resources for constructing training dataset (Spark)
50+
executors: <int> # number of spark executors (default: 1)
51+
driver_cpu: <string> # CPU request for spark driver (default: 1)
52+
driver_mem: <string> # memory request for spark driver (default: 500Mi)
53+
driver_mem_overhead: <string> # off-heap (non-JVM) memory allocated to the driver (overrides mem_overhead_factor) (default: max[driver_mem * 0.4, 384Mi])
54+
executor_cpu: <string> # CPU request for each spark executor (default: 1)
55+
executor_mem: <string> # memory request for each spark executor (default: 500Mi)
56+
executor_mem_overhead: <string> # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: max[executor_mem * 0.4, 384Mi])
57+
mem_overhead_factor: <float> # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4)
58+
4959
tags:
5060
<string>: <scalar> # arbitrary key/value pairs to attach to the resource (optional)
5161
...

pkg/operator/api/userconfig/aggregates.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ var aggregateValidation = &cr.StructValidation{
5151
},
5252
},
5353
inputValuesFieldValidation,
54-
sparkComputeFieldValidation,
54+
sparkComputeFieldValidation("Compute"),
5555
tagsFieldValidation,
5656
typeFieldValidation,
5757
},

pkg/operator/api/userconfig/compute.go

Lines changed: 70 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -45,84 +45,88 @@ type SparkCompute struct {
4545
MemOverheadFactor *float64 `json:"mem_overhead_factor" yaml:"mem_overhead_factor"`
4646
}
4747

48-
var sparkComputeFieldValidation = &cr.StructFieldValidation{
49-
StructField: "Compute",
50-
StructValidation: &cr.StructValidation{
51-
StructFieldValidations: []*cr.StructFieldValidation{
52-
{
53-
StructField: "Executors",
54-
Int32Validation: &cr.Int32Validation{
55-
Default: 1,
56-
GreaterThan: pointer.Int32(0),
57-
},
48+
var sparkComputeStructValidation = &cr.StructValidation{
49+
StructFieldValidations: []*cr.StructFieldValidation{
50+
{
51+
StructField: "Executors",
52+
Int32Validation: &cr.Int32Validation{
53+
Default: 1,
54+
GreaterThan: pointer.Int32(0),
5855
},
59-
{
60-
StructField: "DriverCPU",
61-
StringValidation: &cr.StringValidation{
62-
Default: "1",
63-
},
64-
Parser: QuantityParser(&QuantityValidation{
65-
Min: k8sresource.MustParse("1"),
66-
}),
56+
},
57+
{
58+
StructField: "DriverCPU",
59+
StringValidation: &cr.StringValidation{
60+
Default: "1",
6761
},
68-
{
69-
StructField: "ExecutorCPU",
70-
StringValidation: &cr.StringValidation{
71-
Default: "1",
72-
},
73-
Parser: QuantityParser(&QuantityValidation{
74-
Min: k8sresource.MustParse("1"),
75-
Int: true,
76-
}),
62+
Parser: QuantityParser(&QuantityValidation{
63+
Min: k8sresource.MustParse("1"),
64+
}),
65+
},
66+
{
67+
StructField: "ExecutorCPU",
68+
StringValidation: &cr.StringValidation{
69+
Default: "1",
7770
},
78-
{
79-
StructField: "DriverMem",
80-
StringValidation: &cr.StringValidation{
81-
Default: "500Mi",
82-
},
83-
Parser: QuantityParser(&QuantityValidation{
84-
Min: k8sresource.MustParse("500Mi"),
85-
}),
71+
Parser: QuantityParser(&QuantityValidation{
72+
Min: k8sresource.MustParse("1"),
73+
Int: true,
74+
}),
75+
},
76+
{
77+
StructField: "DriverMem",
78+
StringValidation: &cr.StringValidation{
79+
Default: "500Mi",
8680
},
87-
{
88-
StructField: "ExecutorMem",
89-
StringValidation: &cr.StringValidation{
90-
Default: "500Mi",
91-
},
92-
Parser: QuantityParser(&QuantityValidation{
93-
Min: k8sresource.MustParse("500Mi"),
94-
}),
81+
Parser: QuantityParser(&QuantityValidation{
82+
Min: k8sresource.MustParse("500Mi"),
83+
}),
84+
},
85+
{
86+
StructField: "ExecutorMem",
87+
StringValidation: &cr.StringValidation{
88+
Default: "500Mi",
9589
},
96-
{
97-
StructField: "DriverMemOverhead",
98-
StringPtrValidation: &cr.StringPtrValidation{
99-
Default: nil, // min(DriverMem * 0.4, 384Mi)
100-
},
101-
Parser: QuantityParser(&QuantityValidation{
102-
Min: k8sresource.MustParse("0"),
103-
}),
90+
Parser: QuantityParser(&QuantityValidation{
91+
Min: k8sresource.MustParse("500Mi"),
92+
}),
93+
},
94+
{
95+
StructField: "DriverMemOverhead",
96+
StringPtrValidation: &cr.StringPtrValidation{
97+
Default: nil, // min(DriverMem * 0.4, 384Mi)
10498
},
105-
{
106-
StructField: "ExecutorMemOverhead",
107-
StringPtrValidation: &cr.StringPtrValidation{
108-
Default: nil, // min(ExecutorMem * 0.4, 384Mi)
109-
},
110-
Parser: QuantityParser(&QuantityValidation{
111-
Min: k8sresource.MustParse("0"),
112-
}),
99+
Parser: QuantityParser(&QuantityValidation{
100+
Min: k8sresource.MustParse("0"),
101+
}),
102+
},
103+
{
104+
StructField: "ExecutorMemOverhead",
105+
StringPtrValidation: &cr.StringPtrValidation{
106+
Default: nil, // min(ExecutorMem * 0.4, 384Mi)
113107
},
114-
{
115-
StructField: "MemOverheadFactor",
116-
Float64PtrValidation: &cr.Float64PtrValidation{
117-
Default: nil, // set to 0.4 by Spark
118-
GreaterThanOrEqualTo: pointer.Float64(0),
119-
LessThan: pointer.Float64(1),
120-
},
108+
Parser: QuantityParser(&QuantityValidation{
109+
Min: k8sresource.MustParse("0"),
110+
}),
111+
},
112+
{
113+
StructField: "MemOverheadFactor",
114+
Float64PtrValidation: &cr.Float64PtrValidation{
115+
Default: nil, // set to 0.4 by Spark
116+
GreaterThanOrEqualTo: pointer.Float64(0),
117+
LessThan: pointer.Float64(1),
121118
},
122119
},
123120
},
124121
}
125122

123+
func sparkComputeFieldValidation(fieldName string) *cr.StructFieldValidation {
124+
return &cr.StructFieldValidation{
125+
StructField: fieldName,
126+
StructValidation: sparkComputeStructValidation,
127+
}
128+
}
129+
126130
func (sparkCompute *SparkCompute) ID() string {
127131
var buf bytes.Buffer
128132
buf.WriteString(s.Int32(sparkCompute.Executors))

pkg/operator/api/userconfig/models.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type Model struct {
4242
Training *ModelTraining `json:"training" yaml:"training"`
4343
Evaluation *ModelEvaluation `json:"evaluation" yaml:"evaluation"`
4444
Compute *TFCompute `json:"compute" yaml:"compute"`
45+
DatasetCompute *SparkCompute `json:"dataset_compute" yaml:"dataset_compute"`
4546
Tags Tags `json:"tags" yaml:"tags"`
4647
}
4748

@@ -127,6 +128,7 @@ var modelValidation = &cr.StructValidation{
127128
StructValidation: modelEvaluationValidation,
128129
},
129130
tfComputeFieldValidation,
131+
sparkComputeFieldValidation("DatasetCompute"),
130132
tagsFieldValidation,
131133
typeFieldValidation,
132134
},

pkg/operator/api/userconfig/raw_columns.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{
9696
AllowNull: true,
9797
},
9898
},
99-
sparkComputeFieldValidation,
99+
sparkComputeFieldValidation("Compute"),
100100
tagsFieldValidation,
101101
typeFieldValidation,
102102
}
@@ -145,7 +145,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{
145145
AllowNull: true,
146146
},
147147
},
148-
sparkComputeFieldValidation,
148+
sparkComputeFieldValidation("Compute"),
149149
tagsFieldValidation,
150150
typeFieldValidation,
151151
}
@@ -182,7 +182,7 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{
182182
AllowNull: true,
183183
},
184184
},
185-
sparkComputeFieldValidation,
185+
sparkComputeFieldValidation("Compute"),
186186
tagsFieldValidation,
187187
typeFieldValidation,
188188
}

pkg/operator/api/userconfig/transformed_columns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ var transformedColumnValidation = &cr.StructValidation{
5151
},
5252
},
5353
inputValuesFieldValidation,
54-
sparkComputeFieldValidation,
54+
sparkComputeFieldValidation("Compute"),
5555
tagsFieldValidation,
5656
typeFieldValidation,
5757
},

pkg/operator/workloads/data_job.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ func dataWorkloadSpecs(ctx *context.Context) ([]*WorkloadSpec, error) {
257257
allComputes = append(allComputes, transformedColumn.Compute)
258258
}
259259
}
260+
allComputes = append(allComputes, model.DatasetCompute)
260261
}
261262

262263
resourceIDSet := strset.Union(rawColumnIDs, aggregateIDs, transformedColumnIDs, trainingDatasetIDs)

0 commit comments

Comments
 (0)