From 3fbd972724196535b98aef76b9b7c46ad5140d95 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Wed, 22 May 2019 15:31:50 -0400
Subject: [PATCH 01/27] progress

---
 .../resources/{raw_columns.yaml => raw_data.yaml}  |  8 --------
 examples/reviews/resources/tokenized_columns.yaml  | 10 ++++++++++
 2 files changed, 10 insertions(+), 8 deletions(-)
 rename examples/reviews/resources/{raw_columns.yaml => raw_data.yaml} (61%)

diff --git a/examples/reviews/resources/raw_columns.yaml b/examples/reviews/resources/raw_data.yaml
similarity index 61%
rename from examples/reviews/resources/raw_columns.yaml
rename to examples/reviews/resources/raw_data.yaml
index 5afb119f67..992d784ecc 100644
--- a/examples/reviews/resources/raw_columns.yaml
+++ b/examples/reviews/resources/raw_data.yaml
@@ -7,11 +7,3 @@
       header: true
       escape: "\""
       schema: ["review", "label"]
-
-- kind: raw_column
-  name: review
-  type: STRING_COLUMN
-
-- kind: raw_column
-  name: label
-  type: STRING_COLUMN
diff --git a/examples/reviews/resources/tokenized_columns.yaml b/examples/reviews/resources/tokenized_columns.yaml
index cc2715e76f..911997f1b5 100644
--- a/examples/reviews/resources/tokenized_columns.yaml
+++ b/examples/reviews/resources/tokenized_columns.yaml
@@ -1,3 +1,13 @@
+- kind: environment
+  name: dev
+  data:
+    type: csv
+    path: s3a://cortex-examples/reviews.csv
+    csv_config:
+      header: true
+      escape: "\""
+    schema: ["review", "label"]
+
 - kind: transformed_column
   name: embedding_input
   transformer_path: implementations/transformers/tokenize_string_to_int.py

From 08d682152bef3dc970b73390479b4e9c4b7be803 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Fri, 24 May 2019 15:33:31 -0400
Subject: [PATCH 02/27] progress

---
 cli/cmd/get.go                             |  2 +-
 examples/reviews/resources/raw_data.yaml   |  9 -----
 examples/reviews/resources/vocab.yaml      | 11 +----
 pkg/lib/errors/errors.go                   |  2 +-
 pkg/operator/api/context/raw_columns.go    | 10 ++++-
 pkg/operator/api/context/serialize.go      | 20 +++++++---
 pkg/operator/api/userconfig/column_type.go |  3 ++
 pkg/operator/api/userconfig/config.go      | 24 ++++++++++++
 pkg/operator/api/userconfig/raw_columns.go | 43 ++++++++++++++------
 pkg/operator/api/userconfig/validators.go  |  5 +++
 pkg/operator/context/raw_columns.go        | 15 +++++++-
 11 files changed, 104 insertions(+), 40 deletions(-)
 delete mode 100644 examples/reviews/resources/raw_data.yaml

diff --git a/cli/cmd/get.go b/cli/cmd/get.go
index e4661dd27b..7661e90d72 100644
--- a/cli/cmd/get.go
+++ b/cli/cmd/get.go
@@ -306,7 +306,7 @@ func describeRawColumn(name string, resourcesRes *schema.GetResourcesResponse) (
     }
     dataStatus := resourcesRes.DataStatuses[rawColumn.GetID()]
     out := dataStatusSummary(dataStatus)
-    out += resourceStr(rawColumn.GetUserConfig())
+    out += resourceStr(rawColumn)
     return out, nil
 }
 
diff --git a/examples/reviews/resources/raw_data.yaml b/examples/reviews/resources/raw_data.yaml
deleted file mode 100644
index 992d784ecc..0000000000
--- a/examples/reviews/resources/raw_data.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-- kind: environment
-  name: dev
-  data:
-    type: csv
-    path: s3a://cortex-examples/reviews.csv
-    csv_config:
-      header: true
-      escape: "\""
-    schema: ["review", "label"]
diff --git a/examples/reviews/resources/vocab.yaml b/examples/reviews/resources/vocab.yaml
index 169965f2e9..a2da56f53f 100644
--- a/examples/reviews/resources/vocab.yaml
+++ b/examples/reviews/resources/vocab.yaml
@@ -1,15 +1,6 @@
-- kind: aggregator
-  name: vocab
-  output_type: {STRING: INT}
-  inputs:
-    columns:
-      col: STRING_COLUMN
-    args:
-      vocab_size: INT
-
 - kind: aggregate
   name: reviews_vocab
-  aggregator: vocab
+  aggregator_path: implementations/aggregators/vocab.py
   inputs:
     columns:
       col: review
diff --git a/pkg/lib/errors/errors.go b/pkg/lib/errors/errors.go
index 342e07a307..c3dc2de155 100644
--- a/pkg/lib/errors/errors.go
+++ b/pkg/lib/errors/errors.go
@@ -151,7 +151,7 @@ func Panic(items ...interface{}) {
 func PrintError(err error, strs ...string) {
     wrappedErr := Wrap(err, strs...)
     fmt.Println("error:", wrappedErr.Error())
-    // PrintStacktrace(wrappedErr)
+    PrintStacktrace(wrappedErr)
 }
 
 func PrintStacktrace(err error) {
diff --git a/pkg/operator/api/context/raw_columns.go b/pkg/operator/api/context/raw_columns.go
index b428e8f3ed..353691c77d 100644
--- a/pkg/operator/api/context/raw_columns.go
+++ b/pkg/operator/api/context/raw_columns.go
@@ -27,7 +27,6 @@ type RawColumns map[string]RawColumn
 type RawColumn interface {
     Column
     GetCompute() *userconfig.SparkCompute
-    GetUserConfig() userconfig.Resource
 }
 
 type RawIntColumn struct {
@@ -45,6 +44,11 @@ type RawStringColumn struct {
     *ComputedResourceFields
 }
 
+type RawInferredColumn struct {
+    *userconfig.RawInferredColumn
+    *ComputedResourceFields
+}
+
 func (rawColumns RawColumns) OneByID(id string) RawColumn {
     for _, rawColumn := range rawColumns {
         if rawColumn.GetID() == id {
@@ -98,3 +102,7 @@ func (rawColumn *RawFloatColumn) GetInputRawColumnNames() []string {
 func (rawColumn *RawStringColumn) GetInputRawColumnNames() []string {
     return []string{rawColumn.GetName()}
 }
+
+func (rawColumn *RawInferredColumn) GetInputRawColumnNames() []string {
+    return []string{rawColumn.GetName()}
+}
diff --git a/pkg/operator/api/context/serialize.go b/pkg/operator/api/context/serialize.go
index a3fd126fd3..dee8e00c7d 100644
--- a/pkg/operator/api/context/serialize.go
+++ b/pkg/operator/api/context/serialize.go
@@ -25,9 +25,10 @@ import (
 )
 
 type RawColumnsTypeSplit struct {
-    RawIntColumns    map[string]*RawIntColumn    `json:"raw_int_columns"`
-    RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"`
-    RawFloatColumns  map[string]*RawFloatColumn  `json:"raw_float_columns"`
+    RawIntColumns      map[string]*RawIntColumn      `json:"raw_int_columns"`
+    RawStringColumns   map[string]*RawStringColumn   `json:"raw_string_columns"`
+    RawFloatColumns    map[string]*RawFloatColumn    `json:"raw_float_columns"`
+    RawInferredColumns map[string]*RawInferredColumn `json:"raw_inferred_columns"`
 }
 
 type DataSplit struct {
@@ -45,6 +46,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
     var rawIntColumns = make(map[string]*RawIntColumn)
     var rawFloatColumns = make(map[string]*RawFloatColumn)
     var rawStringColumns = make(map[string]*RawStringColumn)
+    var rawInferredColumns = make(map[string]*RawInferredColumn)
     for name, rawColumn := range ctx.RawColumns {
         switch typedRawColumn := rawColumn.(type) {
         case *RawIntColumn:
@@ -53,13 +55,16 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
             rawFloatColumns[name] = typedRawColumn
         case *RawStringColumn:
             rawStringColumns[name] = typedRawColumn
+        case *RawInferredColumn:
+            rawInferredColumns[name] = typedRawColumn
         }
     }
 
     return &RawColumnsTypeSplit{
-        RawIntColumns:    rawIntColumns,
-        RawFloatColumns:  rawFloatColumns,
-        RawStringColumns: rawStringColumns,
+        RawIntColumns:      rawIntColumns,
+        RawFloatColumns:    rawFloatColumns,
+        RawStringColumns:   rawStringColumns,
+        RawInferredColumns: rawInferredColumns,
     }
 }
 
@@ -75,6 +80,9 @@ func (serial Serial) collectRawColumns() RawColumns {
     for name, rawColumn := range serial.RawColumnSplit.RawStringColumns {
         rawColumns[name] = rawColumn
     }
+    for name, rawColumn := range serial.RawColumnSplit.RawInferredColumns {
+        rawColumns[name] = rawColumn
+    }
 
     return rawColumns
 }
diff --git a/pkg/operator/api/userconfig/column_type.go b/pkg/operator/api/userconfig/column_type.go
index 94c34635ec..b6f6990fe7 100644
--- a/pkg/operator/api/userconfig/column_type.go
+++ b/pkg/operator/api/userconfig/column_type.go
@@ -31,6 +31,7 @@ const (
     IntegerListColumnType
     FloatListColumnType
     StringListColumnType
+    InferredColumnType
 )
 
 var columnTypes = []string{
@@ -41,6 +42,7 @@ var columnTypes = []string{
     "INT_LIST_COLUMN",
     "FLOAT_LIST_COLUMN",
     "STRING_LIST_COLUMN",
+    "INFERRED_COLUMN",
 }
 
 var columnJSONPlaceholders = []string{
@@ -51,6 +53,7 @@ var columnJSONPlaceholders = []string{
     "[INT]",
     "[FLOAT]",
     "[\"STRING\"]",
+    "INFER",
 }
 
 func ColumnTypeFromString(s string) ColumnType {
diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go
index 99a3086b17..bc3ac0b580 100644
--- a/pkg/operator/api/userconfig/config.go
+++ b/pkg/operator/api/userconfig/config.go
@@ -21,6 +21,8 @@ import (
     "io/ioutil"
     "strings"
 
+    k8sresource "k8s.io/apimachinery/pkg/api/resource"
+
     "github.com/cortexlabs/cortex/pkg/lib/cast"
     "github.com/cortexlabs/cortex/pkg/lib/configreader"
     cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
@@ -440,6 +442,28 @@ func New(configs map[string][]byte, envName string) (*Config, error) {
         }
     }
 
+    rawColumnNames := config.RawColumns.Names()
+    for _, env := range config.Environments {
+        ingestedColumnNames := env.Data.GetIngestedColumns()
+        missingColumns := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames)
+        for _, inferredColumn := range missingColumns {
+            inferredRawColumn := &RawInferredColumn{
+                ResourceFields: ResourceFields{
+                    Name: inferredColumn,
+                },
+                Type: InferredColumnType,
+                Compute: &SparkCompute{
+                    Executors:   1,
+                    DriverCPU:   Quantity{Quantity: k8sresource.MustParse("1")},
+                    ExecutorCPU: Quantity{Quantity: k8sresource.MustParse("1")},
+                    DriverMem:   Quantity{Quantity: k8sresource.MustParse("500Mi")},
+                    ExecutorMem: Quantity{Quantity: k8sresource.MustParse("500Mi")},
+                },
+            }
+            config.RawColumns = append(config.RawColumns, inferredRawColumn)
+        }
+    }
+
     if err := config.Validate(envName); err != nil {
         return nil, err
     }
diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go
index 1f04b7dad7..29577e228c 100644
--- a/pkg/operator/api/userconfig/raw_columns.go
+++ b/pkg/operator/api/userconfig/raw_columns.go
@@ -25,7 +25,6 @@ type RawColumn interface {
     Column
     GetType() ColumnType
     GetCompute() *SparkCompute
-    GetUserConfig() Resource
 }
 
 type RawColumns []RawColumn
@@ -181,6 +180,24 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{
     typeFieldValidation,
 }
 
+type RawInferredColumn struct {
+    ResourceFields
+    Type    ColumnType    `json:"type" yaml:"type"`
+    Compute *SparkCompute `json:"compute" yaml:"compute"`
+}
+
+var rawInferredColumnFieldValidations = []*cr.StructFieldValidation{
+    {
+        Key:         "name",
+        StructField: "Name",
+        StringValidation: &cr.StringValidation{
+            AlphaNumericDashUnderscore: true,
+            Required:                   true,
+        },
+    },
+    sparkComputeFieldValidation("Compute"),
+}
+
 func (rawColumns RawColumns) Validate() error {
     resources := make([]Resource, len(rawColumns))
     for i, res := range rawColumns {
@@ -224,6 +241,10 @@ func (column *RawStringColumn) GetType() ColumnType {
     return column.Type
 }
 
+func (column *RawInferredColumn) GetType() ColumnType {
+    return column.Type
+}
+
 func (column *RawIntColumn) GetCompute() *SparkCompute {
     return column.Compute
 }
@@ -236,6 +257,10 @@ func (column *RawStringColumn) GetCompute() *SparkCompute {
     return column.Compute
 }
 
+func (column *RawInferredColumn) GetCompute() *SparkCompute {
+    return column.Compute
+}
+
 func (column *RawIntColumn) GetResourceType() resource.Type {
     return resource.RawColumnType
 }
@@ -248,6 +273,10 @@ func (column *RawStringColumn) GetResourceType() resource.Type {
     return resource.RawColumnType
 }
 
+func (column *RawInferredColumn) GetResourceType() resource.Type {
+    return resource.RawColumnType
+}
+
 func (column *RawIntColumn) IsRaw() bool {
     return true
 }
@@ -260,14 +289,6 @@ func (column *RawStringColumn) IsRaw() bool {
     return true
 }
 
-func (column *RawIntColumn) GetUserConfig() Resource {
-    return column
-}
-
-func (column *RawFloatColumn) GetUserConfig() Resource {
-    return column
-}
-
-func (column *RawStringColumn) GetUserConfig() Resource {
-    return column
+func (column *RawInferredColumn) IsRaw() bool {
+    return true
 }
diff --git a/pkg/operator/api/userconfig/validators.go b/pkg/operator/api/userconfig/validators.go
index b934f3c750..3a107a6bc7 100644
--- a/pkg/operator/api/userconfig/validators.go
+++ b/pkg/operator/api/userconfig/validators.go
@@ -139,6 +139,11 @@ func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, col
         if !ok {
             return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName)
         }
+
+        if columnRuntimeType == InferredColumnType {
+            continue
+        }
+
         if !slices.HasString(validTypes, columnRuntimeType.String()) {
             return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName)
         }
diff --git a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go
index a3102cff45..365e07628e 100644
--- a/pkg/operator/context/raw_columns.go
+++ b/pkg/operator/context/raw_columns.go
@@ -98,8 +98,21 @@ func getRawColumns(
             },
             RawStringColumn: typedColumnConfig,
         }
+    case *userconfig.RawInferredColumn:
+        buf.WriteString(typedColumnConfig.Name)
+        id := hash.Bytes(buf.Bytes())
+        rawColumn = &context.RawInferredColumn{
+            ComputedResourceFields: &context.ComputedResourceFields{
+                ResourceFields: &context.ResourceFields{
+                    ID:           id,
+                    ResourceType: resource.RawColumnType,
+                    MetadataKey:  filepath.Join(consts.RawColumnsDir, id+"_metadata.json"),
+                },
+            },
+            RawInferredColumn: typedColumnConfig,
+        }
     default:
-        return nil, errors.Wrap(configreader.ErrorInvalidStr(userconfig.TypeKey, userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error
+        return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error
     }
 
     rawColumns[columnConfig.GetName()] = rawColumn

From dafa4b001e58f2037e56117dee05256e244a3b19 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Tue, 28 May 2019 11:42:05 -0400
Subject: [PATCH 03/27] raw columns optional

---
 pkg/operator/context/raw_columns.go   |  1 -
 pkg/workloads/consts.py               |  2 ++
 pkg/workloads/lib/context.py          |  8 ++----
 pkg/workloads/lib/tf_lib.py           | 18 ------------
 pkg/workloads/spark_job/spark_util.py | 41 ++++++++++++++++++---------
 pkg/workloads/tf_api/api.py           |  3 +-
 pkg/workloads/tf_train/train_util.py  |  4 +--
 7 files changed, 35 insertions(+), 42 deletions(-)

diff --git a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go
index c53a2107db..9e236ba913 100644
--- a/pkg/operator/context/raw_columns.go
+++ b/pkg/operator/context/raw_columns.go
@@ -101,7 +101,6 @@ func getRawColumns(
             ResourceFields: &context.ResourceFields{
                 ID:           id,
                 ResourceType: resource.RawColumnType,
-                MetadataKey:  filepath.Join(consts.RawColumnsDir, id+"_metadata.json"),
             },
         },
         RawInferredColumn: typedColumnConfig,
diff --git a/pkg/workloads/consts.py b/pkg/workloads/consts.py
index aa6c57f6cf..ed330d68be 100644
--- a/pkg/workloads/consts.py
+++ b/pkg/workloads/consts.py
@@ -20,6 +20,7 @@
 COLUMN_TYPE_INT_LIST = "INT_LIST_COLUMN"
 COLUMN_TYPE_FLOAT_LIST = "FLOAT_LIST_COLUMN"
 COLUMN_TYPE_STRING_LIST = "STRING_LIST_COLUMN"
+COLUMN_TYPE_INFERRED = "INFERRED_COLUMN"
 
 COLUMN_LIST_TYPES = [COLUMN_TYPE_INT_LIST, COLUMN_TYPE_FLOAT_LIST, COLUMN_TYPE_STRING_LIST]
 
@@ -30,6 +31,7 @@
     COLUMN_TYPE_INT_LIST,
     COLUMN_TYPE_FLOAT_LIST,
     COLUMN_TYPE_STRING_LIST,
+    COLUMN_TYPE_INFERRED,
 ]
 
 VALUE_TYPE_INT = "INT"
diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py
index 50ed2c6c7f..639f5ba852 100644
--- a/pkg/workloads/lib/context.py
+++ b/pkg/workloads/lib/context.py
@@ -489,7 +489,7 @@ def get_metadata(self, resource_id, use_cache=True):
     def get_inferred_column_type(self, column_name):
         column = self.columns[column_name]
         column_type = self.columns[column_name].get("type", "unknown")
-        if column_type == "unknown":
+        if column_type == "unknown" or column_type == "INFERRED_COLUMN":
             column_type = self.get_metadata(column["id"])["type"]
             self.columns[column_name]["type"] = column_type
 
@@ -575,11 +575,7 @@ def create_inputs_map(values_map, input_config):
 def _deserialize_raw_ctx(raw_ctx):
     raw_columns = raw_ctx["raw_columns"]
-    raw_ctx["raw_columns"] = util.merge_dicts_overwrite(
-        raw_columns["raw_int_columns"],
-        raw_columns["raw_float_columns"],
-        raw_columns["raw_string_columns"],
-    )
+    raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values())
 
     data_split = raw_ctx["environment_data"]
 
diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py
index 9123e69160..1952149fa0 100644
--- a/pkg/workloads/lib/tf_lib.py
+++ b/pkg/workloads/lib/tf_lib.py
@@ -30,24 +30,6 @@
 }
 
 
-def add_tf_types(config):
-    if not util.is_dict(config):
-        return
-
-    type_fields = {}
-    for k, v in config.items():
-        if util.is_str(k) and util.is_str(v) and v in consts.COLUMN_TYPES:
-            type_fields[k] = v
-        elif util.is_dict(v):
-            add_tf_types(v)
-        elif util.is_list(v):
-            for sub_v in v:
-                add_tf_types(sub_v)
-
-    for k, v in type_fields.items():
-        config[k + "_tf"] = CORTEX_TYPE_TO_TF_TYPE[v]
-
-
 def set_logging_verbosity(verbosity):
     tf.logging.set_verbosity(verbosity)
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(tf.logging.__dict__[verbosity] / 10)
diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index 92ce98a554..15be92d76e 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -223,8 +223,33 @@ def ingest(ctx, spark):
 
     for raw_column_name in ctx.raw_columns.keys():
         raw_column = ctx.raw_columns[raw_column_name]
-        expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[raw_column["type"]]
         actual_type = input_type_map[raw_column_name]
+
+        if actual_type not in SPARK_TYPE_TO_CORTEX_TYPE.keys():
+            actual_type = StringType()
+
+        column_type = raw_column["type"]
+        if column_type == consts.COLUMN_TYPE_INFERRED:
+            TEST_DF_SIZE = 10
+            sample_df = df.select(raw_column_name).limit(TEST_DF_SIZE).collect()
+            sample = sample_df[0][raw_column_name]
+            inferred_type = infer_type(sample)
+
+            for row in sample_df:
+                if inferred_type != infer_type(row[raw_column_name]):
+                    raise UserRuntimeException(
+                        "raw column " + raw_column_name,
+                        "type inference failed, mixed data types in dataframe.",
+                        'expected type of "'
+                        + row
+                        + '" to be '
+                        + inferred_type,
+                    )
+
+            ctx.write_metadata(raw_column["id"], {"type": inferred_type})
+            column_type = inferred_type
+
+        expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[column_type]
         if actual_type not in expected_types:
             logger.error("found schema:")
             log_df_schema(df, logger.error)
@@ -236,7 +261,7 @@ def ingest(ctx, spark):
                     " or ".join(str(x) for x in expected_types), actual_type
                 ),
             )
-        target_type = CORTEX_TYPE_TO_SPARK_TYPE[raw_column["type"]]
+        target_type = CORTEX_TYPE_TO_SPARK_TYPE[column_type]
 
         if target_type != actual_type:
             df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_type))
@@ -246,16 +271,6 @@ def ingest(ctx, spark):
 
 def read_csv(ctx, spark):
     data_config = ctx.environment["data"]
-    expected_field_names = data_config["schema"]
-
-    schema_fields = []
-    for field_name in expected_field_names:
-        if field_name in ctx.raw_columns:
-            spark_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.raw_columns[field_name]["type"]]
-        else:
-            spark_type = StringType()
-
-        schema_fields.append(StructField(name=field_name, dataType=spark_type))
 
     csv_config = {
         util.snake_to_camel(param_name): val
@@ -264,7 +279,7 @@
     }
 
     df = spark.read.csv(
-        data_config["path"], schema=StructType(schema_fields), mode="FAILFAST", **csv_config
+        data_config["path"], inferSchema=True, **csv_config
     )
     return df.select(*ctx.raw_columns.keys())
 
diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py
index 28eb1aab5e..aead7b828d 100644
--- a/pkg/workloads/tf_api/api.py
+++ b/pkg/workloads/tf_api/api.py
@@ -209,7 +209,8 @@ def is_valid_sample(sample):
             return False, "{} is missing".format(column["name"])
 
         sample_val = sample[column["name"]]
-        is_valid = util.CORTEX_TYPE_TO_UPCAST_VALIDATOR[column["type"]](sample_val)
+        column_type = local_cache["ctx"].get_inferred_column_type(column["name"])
+        is_valid = util.CORTEX_TYPE_TO_UPCAST_VALIDATOR[column_type](sample_val)
         if not is_valid:
             return (False, "{} should be a {}".format(column["name"], column["type"]))
 
diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py
index 9e36a08ae6..203fc134bc 100644
--- a/pkg/workloads/tf_train/train_util.py
+++ b/pkg/workloads/tf_train/train_util.py
@@ -157,7 +157,6 @@ def train(model_name, model_impl, ctx, model_dir):
     )
 
     train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=train_num_steps)
-
     eval_num_steps = model["evaluation"]["num_steps"]
     if model["evaluation"]["num_epochs"]:
         eval_num_steps = (
@@ -175,12 +174,11 @@ def train(model_name, model_impl, ctx, model_dir):
     )
 
     model_config = ctx.model_config(model["name"])
-    tf_lib.add_tf_types(model_config)
 
     try:
         estimator = model_impl.create_estimator(run_config, model_config)
     except Exception as e:
-        raise UserRuntimeException("model " + model_name) from e
+            raise UserRuntimeException("model " + model_name) from e
 
     if model["type"] == "regression":
         estimator = tf.contrib.estimator.add_metrics(estimator, get_regression_eval_metrics)

From 4034a054031a59afe037cd6d0d113b5a9f3209e5 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Tue, 28 May 2019 11:50:26 -0400
Subject: [PATCH 04/27] format and comment out print stack

---
 pkg/lib/errors/errors.go              | 2 +-
 pkg/workloads/spark_job/spark_util.py | 9 ++-------
 pkg/workloads/tf_train/train_util.py  | 2 +-
 3 files changed, 4 insertions(+), 9 deletions(-)
diff --git a/pkg/lib/errors/errors.go b/pkg/lib/errors/errors.go
index c3dc2de155..342e07a307 100644
--- a/pkg/lib/errors/errors.go
+++ b/pkg/lib/errors/errors.go
@@ -151,7 +151,7 @@ func Panic(items ...interface{}) {
 func PrintError(err error, strs ...string) {
     wrappedErr := Wrap(err, strs...)
     fmt.Println("error:", wrappedErr.Error())
-    PrintStacktrace(wrappedErr)
+    // PrintStacktrace(wrappedErr)
 }
 
 func PrintStacktrace(err error) {
diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index 15be92d76e..04a40ffa2a 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -240,10 +240,7 @@ def ingest(ctx, spark):
                     raise UserRuntimeException(
                         "raw column " + raw_column_name,
                         "type inference failed, mixed data types in dataframe.",
-                        'expected type of "'
-                        + row
-                        + '" to be '
-                        + inferred_type,
+                        'expected type of "' + row + '" to be ' + inferred_type,
                     )
 
             ctx.write_metadata(raw_column["id"], {"type": inferred_type})
@@ -278,9 +275,7 @@ def read_csv(ctx, spark):
         if val is not None
     }
 
-    df = spark.read.csv(
-        data_config["path"], inferSchema=True, **csv_config
-    )
+    df = spark.read.csv(data_config["path"], inferSchema=True, **csv_config)
     return df.select(*ctx.raw_columns.keys())
 
diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py
index 203fc134bc..d82c572d50 100644
--- a/pkg/workloads/tf_train/train_util.py
+++ b/pkg/workloads/tf_train/train_util.py
@@ -178,7 +178,7 @@ def train(model_name, model_impl, ctx, model_dir):
     try:
         estimator = model_impl.create_estimator(run_config, model_config)
     except Exception as e:
-            raise UserRuntimeException("model " + model_name) from e
+        raise UserRuntimeException("model " + model_name) from e
 
     if model["type"] == "regression":
         estimator = tf.contrib.estimator.add_metrics(estimator, get_regression_eval_metrics)

From ef79bb5b6dbc83bb2d2737e2e02a74372771c859 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Tue, 28 May 2019 12:01:01 -0400
Subject: [PATCH 05/27] add back spacing

---
 pkg/workloads/tf_train/train_util.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py
index d82c572d50..8f7105d25d 100644
--- a/pkg/workloads/tf_train/train_util.py
+++ b/pkg/workloads/tf_train/train_util.py
@@ -157,6 +157,7 @@ def train(model_name, model_impl, ctx, model_dir):
     )
 
     train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=train_num_steps)
+
     eval_num_steps = model["evaluation"]["num_steps"]
     if model["evaluation"]["num_epochs"]:
         eval_num_steps = (

From c8ec015f7fae2a44b58c044e233fad36748c89ae Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Tue, 28 May 2019 16:54:32 -0400
Subject: [PATCH 06/27] fix tests

---
 pkg/workloads/lib/util.py                            | 4 ----
 pkg/workloads/spark_job/spark_util.py                | 9 +++++++++
 pkg/workloads/spark_job/test/unit/spark_util_test.py | 8 ++++----
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py
index 8a5bbf695f..b9ab3fb2a1 100644
--- a/pkg/workloads/lib/util.py
+++ b/pkg/workloads/lib/util.py
@@ -316,10 +316,6 @@ def flatten(var):
         return [var]
 
 
-def subtract_lists(lst1, lst2):
-    return [e for e in lst1 if e not in lst2]
-
-
 def keep_dict_keys(d, keys):
     key_set = set(keys)
     for key in list(d.keys()):
diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index 04a40ffa2a..ab28ab328c 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++
b/pkg/workloads/spark_job/spark_util.py @@ -276,6 +276,15 @@ def read_csv(ctx, spark): } df = spark.read.csv(data_config["path"], inferSchema=True, **csv_config) + renamed_cols = [F.col(c).alias(data_config["schema"][idx]) for idx, c in enumerate(df.columns)] + df = df.select(*renamed_cols) + + missing_cols = set(data_config["schema"]) - set(df.columns) + if len(missing_cols) > 0: + logger.error("found schema:") + log_df_schema(df, logger.error) + raise UserException("missing column(s) in input dataset", str(missing_cols)) + return df.select(*ctx.raw_columns.keys()) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 7b6b71fda0..5c47a0e65a 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -49,7 +49,7 @@ def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): "data": { "type": "csv", "path": path_to_file, - "schema": ["a_str", "b_float", "c_long", "d_long"], + "schema": ["a_str", "b_float", "c_long"], } } @@ -71,12 +71,12 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, } - with pytest.raises(Py4JJavaError): + with pytest.raises(UserException): spark_util.ingest(get_context(ctx_obj), spark).collect() def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): - csv_str = "\n".join(["a,0.1,", "b,1,1"]) + csv_str = "\n".join(["a,1,", "b,1,"]) path_to_file = write_csv_file(csv_str) @@ -90,7 +90,7 @@ def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"}, } - with pytest.raises(Py4JJavaError) as exec_info: + with pytest.raises(UserException) as exec_info: spark_util.ingest(get_context(ctx_obj), spark).collect() From 51e4a27b4577d0f8e078a890a68c78ffb9212f68 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 28 May 2019 17:00:35 -0400 Subject: [PATCH 07/27] format --- pkg/workloads/spark_job/spark_util.py | 10 +++++++--- pkg/workloads/spark_job/test/unit/spark_util_test.py | 6 +----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index ab28ab328c..74ea19d237 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -279,11 +279,15 @@ def read_csv(ctx, spark): renamed_cols = [F.col(c).alias(data_config["schema"][idx]) for idx, c in enumerate(df.columns)] df = df.select(*renamed_cols) - missing_cols = set(data_config["schema"]) - set(df.columns) - if len(missing_cols) > 0: + if set(data_config["schema"]) != set(df.columns): logger.error("found schema:") log_df_schema(df, logger.error) - raise UserException("missing column(s) in input dataset", str(missing_cols)) + raise UserException( + "expected column(s) " + + str(set(data_config["schema"])) + + " but got " + + str(set(df.columns)) + ) return df.select(*ctx.raw_columns.keys()) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 5c47a0e65a..f4434cc32e 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -46,11 +46,7 @@ def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 
ctx_obj["environment"] = { - "data": { - "type": "csv", - "path": path_to_file, - "schema": ["a_str", "b_float", "c_long"], - } + "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} } assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 From ff451f8efd5cc344a3b06db82e6da2242c792acc Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 15:14:04 -0400 Subject: [PATCH 08/27] address some comments --- cli/cmd/get.go | 2 +- .../{tokenized_columns.yaml => columns.yaml} | 0 pkg/operator/api/context/raw_columns.go | 15 ++++++++ pkg/operator/api/userconfig/config.go | 25 +++++--------- pkg/operator/api/userconfig/raw_columns.go | 12 ------- pkg/operator/context/transformers.go | 3 +- pkg/workloads/lib/context.py | 2 +- pkg/workloads/lib/tf_lib.py | 5 ++- pkg/workloads/spark_job/spark_util.py | 34 +++---------------- .../spark_job/test/unit/spark_util_test.py | 6 ---- pkg/workloads/tf_api/api.py | 9 ++--- 11 files changed, 40 insertions(+), 73 deletions(-) rename examples/reviews/resources/{tokenized_columns.yaml => columns.yaml} (100%) diff --git a/cli/cmd/get.go b/cli/cmd/get.go index 7661e90d72..c35771e70d 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -306,7 +306,7 @@ func describeRawColumn(name string, resourcesRes *schema.GetResourcesResponse) ( } dataStatus := resourcesRes.DataStatuses[rawColumn.GetID()] out := dataStatusSummary(dataStatus) - out += resourceStr(rawColumn) + out += resourceStr(context.GetRawColumnUserConfig(rawColumn)) return out, nil } diff --git a/examples/reviews/resources/tokenized_columns.yaml b/examples/reviews/resources/columns.yaml similarity index 100% rename from examples/reviews/resources/tokenized_columns.yaml rename to examples/reviews/resources/columns.yaml diff --git a/pkg/operator/api/context/raw_columns.go b/pkg/operator/api/context/raw_columns.go index 353691c77d..6a6830a19a 100644 --- a/pkg/operator/api/context/raw_columns.go +++ b/pkg/operator/api/context/raw_columns.go @@ -83,6 +83,21 @@ func (rawColumns RawColumns) columnInputsID(columnInputValues map[string]interfa return hash.Any(columnIDMap) } +func GetRawColumnUserConfig(rawColumn RawColumn) userconfig.Resource { + switch rawColumn.GetType() { + case userconfig.IntegerColumnType: + return rawColumn.(*RawIntColumn).RawIntColumn + case userconfig.FloatColumnType: + return rawColumn.(*RawFloatColumn).RawFloatColumn + case userconfig.StringColumnType: + return rawColumn.(*RawStringColumn).RawStringColumn + case userconfig.InferredColumnType: + return rawColumn.(*RawInferredColumn).RawInferredColumn + } + + return nil +} + func (rawColumns RawColumns) ColumnInputsID(columnInputValues map[string]interface{}) string { return rawColumns.columnInputsID(columnInputValues, false) } diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index bc3ac0b580..c69acb40e8 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -21,8 +21,6 @@ import ( "io/ioutil" "strings" - k8sresource "k8s.io/apimachinery/pkg/api/resource" - "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/lib/configreader" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" @@ -156,9 +154,9 @@ func (config *Config) Validate(envName string) error { rawColumnNames := config.RawColumns.Names() for _, env := range config.Environments { ingestedColumnNames := env.Data.GetIngestedColumns() - missingColumns := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames) - if 
len(missingColumns) > 0 { - return errors.Wrap(ErrorRawColumnNotInEnv(env.Name), Identify(config.RawColumns.Get(missingColumns[0]))) + missingColumnNames := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames) + if len(missingColumnNames) > 0 { + return errors.Wrap(ErrorRawColumnNotInEnv(env.Name), Identify(config.RawColumns.Get(missingColumnNames[0]))) } extraColumns := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames) if len(extraColumns) > 0 { @@ -445,21 +443,16 @@ func New(configs map[string][]byte, envName string) (*Config, error) { rawColumnNames := config.RawColumns.Names() for _, env := range config.Environments { ingestedColumnNames := env.Data.GetIngestedColumns() - missingColumns := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames) - for _, inferredColumn := range missingColumns { + missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames) + for _, inferredColumnName := range missingColumnNames { inferredRawColumn := &RawInferredColumn{ ResourceFields: ResourceFields{ - Name: inferredColumn, - }, - Type: InferredColumnType, - Compute: &SparkCompute{ - Executors: 1, - DriverCPU: Quantity{Quantity: k8sresource.MustParse("1")}, - ExecutorCPU: Quantity{Quantity: k8sresource.MustParse("1")}, - DriverMem: Quantity{Quantity: k8sresource.MustParse("500Mi")}, - ExecutorMem: Quantity{Quantity: k8sresource.MustParse("500Mi")}, + Name: inferredColumnName, }, + Type: InferredColumnType, + Compute: &SparkCompute{}, } + cr.Struct(inferredRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) config.RawColumns = append(config.RawColumns, inferredRawColumn) } } diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go index 29577e228c..8bda695782 100644 --- a/pkg/operator/api/userconfig/raw_columns.go +++ b/pkg/operator/api/userconfig/raw_columns.go @@ -186,18 +186,6 @@ type RawInferredColumn struct { Compute *SparkCompute `json:"compute" yaml:"compute"` } -var rawInferredColumnFieldValidations = []*cr.StructFieldValidation{ - { - Key: "name", - StructField: "Name", - StringValidation: &cr.StringValidation{ - AlphaNumericDashUnderscore: true, - Required: true, - }, - }, - sparkComputeFieldValidation("Compute"), -} - func (rawColumns RawColumns) Validate() error { resources := make([]Resource, len(rawColumns)) for i, res := range rawColumns { diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index 859e0fe75f..d19aac4fbd 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -67,7 +67,8 @@ func loadUserTransformers( ResourceFields: userconfig.ResourceFields{ Name: implHash, }, - Path: *transColConfig.TransformerPath, + OutputType: userconfig.InferredColumnType, + Path: *transColConfig.TransformerPath, } transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages) if err != nil { diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 639f5ba852..4ff01c2058 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -489,7 +489,7 @@ def get_metadata(self, resource_id, use_cache=True): def get_inferred_column_type(self, column_name): column = self.columns[column_name] column_type = self.columns[column_name].get("type", "unknown") - if column_type == "unknown" or column_type == "INFERRED_COLUMN": + if column_type == consts.COLUMN_TYPE_INFERRED: column_type = self.get_metadata(column["id"])["type"] self.columns[column_name]["type"] = 
column_type diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index 1952149fa0..09962f5d50 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -46,9 +46,8 @@ def get_column_tf_types(model_name, ctx, training=True): if training: target_column_name = model["target_column"] - column_types[target_column_name] = CORTEX_TYPE_TO_TF_TYPE[ - ctx.columns[target_column_name]["type"] - ] + column_type = ctx.get_inferred_column_type(target_column_name) + column_types[target_column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] for column_name in model["training_columns"]: column_type = ctx.get_inferred_column_type(column_name) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 74ea19d237..b24072d6bb 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -225,24 +225,14 @@ def ingest(ctx, spark): raw_column = ctx.raw_columns[raw_column_name] actual_type = input_type_map[raw_column_name] - if actual_type not in SPARK_TYPE_TO_CORTEX_TYPE.keys(): + if actual_type not in SPARK_TYPE_TO_CORTEX_TYPE: actual_type = StringType() column_type = raw_column["type"] if column_type == consts.COLUMN_TYPE_INFERRED: - TEST_DF_SIZE = 10 - sample_df = df.select(raw_column_name).limit(TEST_DF_SIZE).collect() + sample_df = df.select(raw_column_name).limit(1).collect() sample = sample_df[0][raw_column_name] inferred_type = infer_type(sample) - - for row in sample_df: - if inferred_type != infer_type(row[raw_column_name]): - raise UserRuntimeException( - "raw column " + raw_column_name, - "type inference failed, mixed data types in dataframe.", - 'expected type of "' + row + '" to be ' + inferred_type, - ) - ctx.write_metadata(raw_column["id"], {"type": inferred_type}) column_type = inferred_type @@ -527,27 +517,13 @@ def validate_transformer(column_name, test_df, ctx, spark): if hasattr(trans_impl, "transform_python"): try: - if transformer["output_type"] == "unknown": + if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: sample_df = test_df.collect() sample = sample_df[0] inputs = ctx.create_column_inputs_map(sample, column_name) _, impl_args = extract_inputs(column_name, ctx) initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) inferred_python_type = infer_type(initial_transformed_sample) - - for row in sample_df: - inputs = ctx.create_column_inputs_map(row, column_name) - transformed_sample = trans_impl.transform_python(inputs, impl_args) - if inferred_python_type != infer_type(transformed_sample): - raise UserRuntimeException( - "transformed column " + column_name, - "type inference failed, mixed data types in dataframe.", - 'expected type of "' - + transformed_sample - + '" to be ' - + inferred_python_type, - ) - ctx.write_metadata(transformed_column["id"], {"type": inferred_python_type}) transform_python_collect = execute_transform_python( @@ -609,7 +585,7 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ) - if transformer["output_type"] == "unknown": + if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) @@ -641,7 +617,7 @@ def validate_transformer(column_name, test_df, ctx, spark): raise if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): - if transformer["output_type"] == "unknown" and inferred_spark_type != inferred_python_type: + if 
transformer["output_type"] == consts.COLUMN_TYPE_INFERRED and inferred_spark_type != inferred_python_type: raise UserRuntimeException( "transformed column " + column_name, "type inference failed, transform_spark and transform_python had differing types.", diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index f4434cc32e..a72f0b53b2 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -45,12 +45,6 @@ def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 - ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} - } - - assert spark_util.read_csv(get_context(ctx_obj), spark).count() == 3 - def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,0.1,", "b,1,1", "c,1.1,4"]) diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index aead7b828d..11390410b0 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -95,8 +95,8 @@ def create_prediction_request(transformed_sample): prediction_request.model_spec.signature_name = signature_key for column_name, value in transformed_sample.items(): - column_Type = ctx.get_inferred_column_type(column_name) - data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[column_Type] + column_type = ctx.get_inferred_column_type(column_name) + data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[column_type] shape = [1] if util.is_list(value): shape = [len(value)] @@ -213,7 +213,7 @@ def is_valid_sample(sample): is_valid = util.CORTEX_TYPE_TO_UPCAST_VALIDATOR[column_type](sample_val) if not is_valid: - return (False, "{} should be a {}".format(column["name"], column["type"])) + return (False, "{} should be a {}".format(column["name"], column_type)) return True, None @@ -266,7 +266,8 @@ def predict(app_name, api_name): return prediction_failed(sample, reason) for column in local_cache["required_inputs"]: - sample[column["name"]] = util.upcast(sample[column["name"]], column["type"]) + column_type = local_cache["ctx"].get_inferred_column_type(column["name"]) + sample[column["name"]] = util.upcast(sample[column["name"]], column_type) try: result = run_predict(sample) From 40116becf27c258b971ef4d8b356ddb973f85aa8 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 17:20:12 -0400 Subject: [PATCH 09/27] infer -> value --- pkg/operator/api/context/raw_columns.go | 10 ++++----- pkg/operator/api/context/serialize.go | 24 +++++++++++----------- pkg/operator/api/userconfig/column_type.go | 6 +++--- pkg/operator/api/userconfig/config.go | 12 +++++------ pkg/operator/api/userconfig/raw_columns.go | 10 ++++----- pkg/operator/api/userconfig/validators.go | 2 +- pkg/operator/context/raw_columns.go | 6 +++--- pkg/operator/context/transformers.go | 2 +- pkg/workloads/consts.py | 4 ++-- pkg/workloads/lib/context.py | 2 +- pkg/workloads/spark_job/spark_util.py | 8 ++++---- 11 files changed, 43 insertions(+), 43 deletions(-) diff --git a/pkg/operator/api/context/raw_columns.go b/pkg/operator/api/context/raw_columns.go index 6a6830a19a..796105f188 100644 --- a/pkg/operator/api/context/raw_columns.go +++ b/pkg/operator/api/context/raw_columns.go @@ -44,8 +44,8 @@ type RawStringColumn struct { *ComputedResourceFields } -type RawInferredColumn struct { - *userconfig.RawInferredColumn +type RawValueColumn struct { + 
*userconfig.RawValueColumn *ComputedResourceFields } @@ -91,8 +91,8 @@ func GetRawColumnUserConfig(rawColumn RawColumn) userconfig.Resource { return rawColumn.(*RawFloatColumn).RawFloatColumn case userconfig.StringColumnType: return rawColumn.(*RawStringColumn).RawStringColumn - case userconfig.InferredColumnType: - return rawColumn.(*RawInferredColumn).RawInferredColumn + case userconfig.ValueColumnType: + return rawColumn.(*RawValueColumn).RawValueColumn } return nil @@ -118,6 +118,6 @@ func (rawColumn *RawStringColumn) GetInputRawColumnNames() []string { return []string{rawColumn.GetName()} } -func (rawColumn *RawInferredColumn) GetInputRawColumnNames() []string { +func (rawColumn *RawValueColumn) GetInputRawColumnNames() []string { return []string{rawColumn.GetName()} } diff --git a/pkg/operator/api/context/serialize.go b/pkg/operator/api/context/serialize.go index dee8e00c7d..2bd9b0b2a6 100644 --- a/pkg/operator/api/context/serialize.go +++ b/pkg/operator/api/context/serialize.go @@ -25,10 +25,10 @@ import ( ) type RawColumnsTypeSplit struct { - RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"` - RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"` - RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"` - RawInferredColumns map[string]*RawInferredColumn `json:"raw_inferred_columns"` + RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"` + RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"` + RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"` + RawValueColumns map[string]*RawValueColumn `json:"raw_value_columns"` } type DataSplit struct { @@ -46,7 +46,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit { var rawIntColumns = make(map[string]*RawIntColumn) var rawFloatColumns = make(map[string]*RawFloatColumn) var rawStringColumns = make(map[string]*RawStringColumn) - var rawInferredColumns = make(map[string]*RawInferredColumn) + var rawValueColumns = make(map[string]*RawValueColumn) for name, rawColumn := range ctx.RawColumns { switch typedRawColumn := rawColumn.(type) { case *RawIntColumn: @@ -55,16 +55,16 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit { rawFloatColumns[name] = typedRawColumn case *RawStringColumn: rawStringColumns[name] = typedRawColumn - case *RawInferredColumn: - rawInferredColumns[name] = typedRawColumn + case *RawValueColumn: + rawValueColumns[name] = typedRawColumn } } return &RawColumnsTypeSplit{ - RawIntColumns: rawIntColumns, - RawFloatColumns: rawFloatColumns, - RawStringColumns: rawStringColumns, - RawInferredColumns: rawInferredColumns, + RawIntColumns: rawIntColumns, + RawFloatColumns: rawFloatColumns, + RawStringColumns: rawStringColumns, + RawValueColumns: rawValueColumns, } } @@ -80,7 +80,7 @@ func (serial Serial) collectRawColumns() RawColumns { for name, rawColumn := range serial.RawColumnSplit.RawStringColumns { rawColumns[name] = rawColumn } - for name, rawColumn := range serial.RawColumnSplit.RawInferredColumns { + for name, rawColumn := range serial.RawColumnSplit.RawValueColumns { rawColumns[name] = rawColumn } diff --git a/pkg/operator/api/userconfig/column_type.go b/pkg/operator/api/userconfig/column_type.go index b6f6990fe7..c682dd09a5 100644 --- a/pkg/operator/api/userconfig/column_type.go +++ b/pkg/operator/api/userconfig/column_type.go @@ -31,7 +31,7 @@ const ( IntegerListColumnType FloatListColumnType StringListColumnType - InferredColumnType + ValueColumnType ) var columnTypes = []string{ @@ -42,7 +42,7 @@ var 
columnTypes = []string{ "INT_LIST_COLUMN", "FLOAT_LIST_COLUMN", "STRING_LIST_COLUMN", - "INFERRED_COLUMN", + "VALUE_COLUMN", } var columnJSONPlaceholders = []string{ @@ -53,7 +53,7 @@ var columnJSONPlaceholders = []string{ "[INT]", "[FLOAT]", "[\"STRING\"]", - "INFER", + "VALUE", } func ColumnTypeFromString(s string) ColumnType { diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index c69acb40e8..5d5e36eb2f 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -444,16 +444,16 @@ func New(configs map[string][]byte, envName string) (*Config, error) { for _, env := range config.Environments { ingestedColumnNames := env.Data.GetIngestedColumns() missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames) - for _, inferredColumnName := range missingColumnNames { - inferredRawColumn := &RawInferredColumn{ + for _, valueColumnName := range missingColumnNames { + valueRawColumn := &RawValueColumn{ ResourceFields: ResourceFields{ - Name: inferredColumnName, + Name: valueColumnName, }, - Type: InferredColumnType, + Type: ValueColumnType, Compute: &SparkCompute{}, } - cr.Struct(inferredRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) - config.RawColumns = append(config.RawColumns, inferredRawColumn) + cr.Struct(valueRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) + config.RawColumns = append(config.RawColumns, valueRawColumn) } } diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go index 8bda695782..5de8b03d6a 100644 --- a/pkg/operator/api/userconfig/raw_columns.go +++ b/pkg/operator/api/userconfig/raw_columns.go @@ -180,7 +180,7 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{ typeFieldValidation, } -type RawInferredColumn struct { +type RawValueColumn struct { ResourceFields Type ColumnType `json:"type" yaml:"type"` Compute *SparkCompute `json:"compute" yaml:"compute"` @@ -229,7 +229,7 @@ func (column *RawStringColumn) GetType() ColumnType { return column.Type } -func (column *RawInferredColumn) GetType() ColumnType { +func (column *RawValueColumn) GetType() ColumnType { return column.Type } @@ -245,7 +245,7 @@ func (column *RawStringColumn) GetCompute() *SparkCompute { return column.Compute } -func (column *RawInferredColumn) GetCompute() *SparkCompute { +func (column *RawValueColumn) GetCompute() *SparkCompute { return column.Compute } @@ -261,7 +261,7 @@ func (column *RawStringColumn) GetResourceType() resource.Type { return resource.RawColumnType } -func (column *RawInferredColumn) GetResourceType() resource.Type { +func (column *RawValueColumn) GetResourceType() resource.Type { return resource.RawColumnType } @@ -277,6 +277,6 @@ func (column *RawStringColumn) IsRaw() bool { return true } -func (column *RawInferredColumn) IsRaw() bool { +func (column *RawValueColumn) IsRaw() bool { return true } diff --git a/pkg/operator/api/userconfig/validators.go b/pkg/operator/api/userconfig/validators.go index 3a107a6bc7..e6f275b78b 100644 --- a/pkg/operator/api/userconfig/validators.go +++ b/pkg/operator/api/userconfig/validators.go @@ -140,7 +140,7 @@ func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, col return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName) } - if columnRuntimeType == InferredColumnType { + if columnRuntimeType == ValueColumnType { continue } diff --git 
a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go index 9e236ba913..579ab9960b 100644 --- a/pkg/operator/context/raw_columns.go +++ b/pkg/operator/context/raw_columns.go @@ -93,17 +93,17 @@ func getRawColumns( }, RawStringColumn: typedColumnConfig, } - case *userconfig.RawInferredColumn: + case *userconfig.RawValueColumn: buf.WriteString(typedColumnConfig.Name) id := hash.Bytes(buf.Bytes()) - rawColumn = &context.RawInferredColumn{ + rawColumn = &context.RawValueColumn{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, ResourceType: resource.RawColumnType, }, }, - RawInferredColumn: typedColumnConfig, + RawValueColumn: typedColumnConfig, } default: return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index d19aac4fbd..16153cd4eb 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -67,7 +67,7 @@ func loadUserTransformers( ResourceFields: userconfig.ResourceFields{ Name: implHash, }, - OutputType: userconfig.InferredColumnType, + OutputType: userconfig.ValueColumnType, Path: *transColConfig.TransformerPath, } transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages) diff --git a/pkg/workloads/consts.py b/pkg/workloads/consts.py index ed330d68be..e0d8875ff6 100644 --- a/pkg/workloads/consts.py +++ b/pkg/workloads/consts.py @@ -20,7 +20,7 @@ COLUMN_TYPE_INT_LIST = "INT_LIST_COLUMN" COLUMN_TYPE_FLOAT_LIST = "FLOAT_LIST_COLUMN" COLUMN_TYPE_STRING_LIST = "STRING_LIST_COLUMN" -COLUMN_TYPE_INFERRED = "INFERRED_COLUMN" +COLUMN_TYPE_VALUE = "VALUE_COLUMN" COLUMN_LIST_TYPES = [COLUMN_TYPE_INT_LIST, COLUMN_TYPE_FLOAT_LIST, COLUMN_TYPE_STRING_LIST] @@ -31,7 +31,7 @@ COLUMN_TYPE_INT_LIST, COLUMN_TYPE_FLOAT_LIST, COLUMN_TYPE_STRING_LIST, - COLUMN_TYPE_INFERRED, + COLUMN_TYPE_VALUE, ] VALUE_TYPE_INT = "INT" diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 4ff01c2058..a7ab5cfc60 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -489,7 +489,7 @@ def get_metadata(self, resource_id, use_cache=True): def get_inferred_column_type(self, column_name): column = self.columns[column_name] column_type = self.columns[column_name].get("type", "unknown") - if column_type == consts.COLUMN_TYPE_INFERRED: + if column_type == consts.COLUMN_TYPE_VALUE: column_type = self.get_metadata(column["id"])["type"] self.columns[column_name]["type"] = column_type diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index b24072d6bb..9b0231cd92 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -229,7 +229,7 @@ def ingest(ctx, spark): actual_type = StringType() column_type = raw_column["type"] - if column_type == consts.COLUMN_TYPE_INFERRED: + if column_type == consts.COLUMN_TYPE_VALUE: sample_df = df.select(raw_column_name).limit(1).collect() sample = sample_df[0][raw_column_name] inferred_type = infer_type(sample) @@ -517,7 +517,7 @@ def validate_transformer(column_name, test_df, ctx, spark): if hasattr(trans_impl, "transform_python"): try: - if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: + if transformer["output_type"] == 
consts.COLUMN_TYPE_VALUE: sample_df = test_df.collect() sample = sample_df[0] inputs = ctx.create_column_inputs_map(sample, column_name) @@ -585,7 +585,7 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ) - if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: + if transformer["output_type"] == consts.COLUMN_TYPE_VALUE: inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) @@ -617,7 +617,7 @@ def validate_transformer(column_name, test_df, ctx, spark): raise if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): - if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED and inferred_spark_type != inferred_python_type: + if transformer["output_type"] == consts.COLUMN_TYPE_VALUE and inferred_spark_type != inferred_python_type: raise UserRuntimeException( "transformed column " + column_name, "type inference failed, transform_spark and transform_python had differing types.", From 25a48cd911d91bbb4182035ea7a25506b4c47521 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 18:27:37 -0400 Subject: [PATCH 10/27] add type casting --- pkg/workloads/spark_job/spark_util.py | 32 ++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 9b0231cd92..b10a8eac5a 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -279,6 +279,19 @@ def read_csv(ctx, spark): + str(set(df.columns)) ) + # for columns with types, cast it + casted_cols = [] + for column_name in ctx.raw_columns: + column_type = ctx.raw_columns[column_name]["type"] + if column_type == consts.COLUMN_TYPE_VALUE: + casted_cols.append(F.col(column_name).alias(column_name)) + else: + casted_cols.append( + F.col(column_name).cast(CORTEX_TYPE_TO_SPARK_TYPE[column_type]).alias(column_name) + ) + + df = df.select(*casted_cols) + return df.select(*ctx.raw_columns.keys()) @@ -524,6 +537,20 @@ def validate_transformer(column_name, test_df, ctx, spark): _, impl_args = extract_inputs(column_name, ctx) initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) inferred_python_type = infer_type(initial_transformed_sample) + + for row in sample_df: + inputs = ctx.create_column_inputs_map(row, column_name) + transformed_sample = trans_impl.transform_python(inputs, impl_args) + if inferred_python_type != infer_type(transformed_sample): + raise UserRuntimeException( + "transformed column " + column_name, + "type inference failed, mixed data types in dataframe.", + 'expected type of "' + + transformed_sample + + '" to be ' + + inferred_python_type, + ) + ctx.write_metadata(transformed_column["id"], {"type": inferred_python_type}) transform_python_collect = execute_transform_python( @@ -617,7 +644,10 @@ def validate_transformer(column_name, test_df, ctx, spark): raise if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): - if transformer["output_type"] == consts.COLUMN_TYPE_VALUE and inferred_spark_type != inferred_python_type: + if ( + transformer["output_type"] == consts.COLUMN_TYPE_VALUE + and inferred_spark_type != inferred_python_type + ): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, transform_spark and transform_python had differing types.", From d1a32b7b416fdcaa569383c5ea6da9479bec7610 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 
May 2019 18:28:25 -0400 Subject: [PATCH 11/27] value -> infer --- pkg/operator/api/context/raw_columns.go | 10 ++++----- pkg/operator/api/context/serialize.go | 24 +++++++++++----------- pkg/operator/api/userconfig/column_type.go | 6 +++--- pkg/operator/api/userconfig/config.go | 12 +++++------ pkg/operator/api/userconfig/raw_columns.go | 10 ++++----- pkg/operator/api/userconfig/validators.go | 2 +- pkg/operator/context/raw_columns.go | 6 +++--- pkg/operator/context/transformers.go | 2 +- pkg/workloads/consts.py | 4 ++-- pkg/workloads/lib/context.py | 2 +- pkg/workloads/spark_job/spark_util.py | 8 ++++---- 11 files changed, 43 insertions(+), 43 deletions(-) diff --git a/pkg/operator/api/context/raw_columns.go b/pkg/operator/api/context/raw_columns.go index 796105f188..6a6830a19a 100644 --- a/pkg/operator/api/context/raw_columns.go +++ b/pkg/operator/api/context/raw_columns.go @@ -44,8 +44,8 @@ type RawStringColumn struct { *ComputedResourceFields } -type RawValueColumn struct { - *userconfig.RawValueColumn +type RawInferredColumn struct { + *userconfig.RawInferredColumn *ComputedResourceFields } @@ -91,8 +91,8 @@ func GetRawColumnUserConfig(rawColumn RawColumn) userconfig.Resource { return rawColumn.(*RawFloatColumn).RawFloatColumn case userconfig.StringColumnType: return rawColumn.(*RawStringColumn).RawStringColumn - case userconfig.ValueColumnType: - return rawColumn.(*RawValueColumn).RawValueColumn + case userconfig.InferredColumnType: + return rawColumn.(*RawInferredColumn).RawInferredColumn } return nil @@ -118,6 +118,6 @@ func (rawColumn *RawStringColumn) GetInputRawColumnNames() []string { return []string{rawColumn.GetName()} } -func (rawColumn *RawValueColumn) GetInputRawColumnNames() []string { +func (rawColumn *RawInferredColumn) GetInputRawColumnNames() []string { return []string{rawColumn.GetName()} } diff --git a/pkg/operator/api/context/serialize.go b/pkg/operator/api/context/serialize.go index 2bd9b0b2a6..dee8e00c7d 100644 --- a/pkg/operator/api/context/serialize.go +++ b/pkg/operator/api/context/serialize.go @@ -25,10 +25,10 @@ import ( ) type RawColumnsTypeSplit struct { - RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"` - RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"` - RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"` - RawValueColumns map[string]*RawValueColumn `json:"raw_value_columns"` + RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"` + RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"` + RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"` + RawInferredColumns map[string]*RawInferredColumn `json:"raw_inferred_columns"` } type DataSplit struct { @@ -46,7 +46,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit { var rawIntColumns = make(map[string]*RawIntColumn) var rawFloatColumns = make(map[string]*RawFloatColumn) var rawStringColumns = make(map[string]*RawStringColumn) - var rawValueColumns = make(map[string]*RawValueColumn) + var rawInferredColumns = make(map[string]*RawInferredColumn) for name, rawColumn := range ctx.RawColumns { switch typedRawColumn := rawColumn.(type) { case *RawIntColumn: @@ -55,16 +55,16 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit { rawFloatColumns[name] = typedRawColumn case *RawStringColumn: rawStringColumns[name] = typedRawColumn - case *RawValueColumn: - rawValueColumns[name] = typedRawColumn + case *RawInferredColumn: + rawInferredColumns[name] = typedRawColumn } } return &RawColumnsTypeSplit{ - 
RawIntColumns: rawIntColumns, - RawFloatColumns: rawFloatColumns, - RawStringColumns: rawStringColumns, - RawValueColumns: rawValueColumns, + RawIntColumns: rawIntColumns, + RawFloatColumns: rawFloatColumns, + RawStringColumns: rawStringColumns, + RawInferredColumns: rawInferredColumns, } } @@ -80,7 +80,7 @@ func (serial Serial) collectRawColumns() RawColumns { for name, rawColumn := range serial.RawColumnSplit.RawStringColumns { rawColumns[name] = rawColumn } - for name, rawColumn := range serial.RawColumnSplit.RawValueColumns { + for name, rawColumn := range serial.RawColumnSplit.RawInferredColumns { rawColumns[name] = rawColumn } diff --git a/pkg/operator/api/userconfig/column_type.go b/pkg/operator/api/userconfig/column_type.go index c682dd09a5..b6f6990fe7 100644 --- a/pkg/operator/api/userconfig/column_type.go +++ b/pkg/operator/api/userconfig/column_type.go @@ -31,7 +31,7 @@ const ( IntegerListColumnType FloatListColumnType StringListColumnType - ValueColumnType + InferredColumnType ) var columnTypes = []string{ @@ -42,7 +42,7 @@ var columnTypes = []string{ "INT_LIST_COLUMN", "FLOAT_LIST_COLUMN", "STRING_LIST_COLUMN", - "VALUE_COLUMN", + "INFERRED_COLUMN", } var columnJSONPlaceholders = []string{ @@ -53,7 +53,7 @@ var columnJSONPlaceholders = []string{ "[INT]", "[FLOAT]", "[\"STRING\"]", - "VALUE", + "INFER", } func ColumnTypeFromString(s string) ColumnType { diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index 5d5e36eb2f..c69acb40e8 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -444,16 +444,16 @@ func New(configs map[string][]byte, envName string) (*Config, error) { for _, env := range config.Environments { ingestedColumnNames := env.Data.GetIngestedColumns() missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames) - for _, valueColumnName := range missingColumnNames { - valueRawColumn := &RawValueColumn{ + for _, inferredColumnName := range missingColumnNames { + inferredRawColumn := &RawInferredColumn{ ResourceFields: ResourceFields{ - Name: valueColumnName, + Name: inferredColumnName, }, - Type: ValueColumnType, + Type: InferredColumnType, Compute: &SparkCompute{}, } - cr.Struct(valueRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) - config.RawColumns = append(config.RawColumns, valueRawColumn) + cr.Struct(inferredRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) + config.RawColumns = append(config.RawColumns, inferredRawColumn) } } diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go index 5de8b03d6a..8bda695782 100644 --- a/pkg/operator/api/userconfig/raw_columns.go +++ b/pkg/operator/api/userconfig/raw_columns.go @@ -180,7 +180,7 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{ typeFieldValidation, } -type RawValueColumn struct { +type RawInferredColumn struct { ResourceFields Type ColumnType `json:"type" yaml:"type"` Compute *SparkCompute `json:"compute" yaml:"compute"` @@ -229,7 +229,7 @@ func (column *RawStringColumn) GetType() ColumnType { return column.Type } -func (column *RawValueColumn) GetType() ColumnType { +func (column *RawInferredColumn) GetType() ColumnType { return column.Type } @@ -245,7 +245,7 @@ func (column *RawStringColumn) GetCompute() *SparkCompute { return column.Compute } -func (column *RawValueColumn) GetCompute() *SparkCompute { +func (column *RawInferredColumn) GetCompute() *SparkCompute { return 
column.Compute } @@ -261,7 +261,7 @@ func (column *RawStringColumn) GetResourceType() resource.Type { return resource.RawColumnType } -func (column *RawValueColumn) GetResourceType() resource.Type { +func (column *RawInferredColumn) GetResourceType() resource.Type { return resource.RawColumnType } @@ -277,6 +277,6 @@ func (column *RawStringColumn) IsRaw() bool { return true } -func (column *RawValueColumn) IsRaw() bool { +func (column *RawInferredColumn) IsRaw() bool { return true } diff --git a/pkg/operator/api/userconfig/validators.go b/pkg/operator/api/userconfig/validators.go index e6f275b78b..3a107a6bc7 100644 --- a/pkg/operator/api/userconfig/validators.go +++ b/pkg/operator/api/userconfig/validators.go @@ -140,7 +140,7 @@ func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, col return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName) } - if columnRuntimeType == ValueColumnType { + if columnRuntimeType == InferredColumnType { continue } diff --git a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go index 579ab9960b..9e236ba913 100644 --- a/pkg/operator/context/raw_columns.go +++ b/pkg/operator/context/raw_columns.go @@ -93,17 +93,17 @@ func getRawColumns( }, RawStringColumn: typedColumnConfig, } - case *userconfig.RawValueColumn: + case *userconfig.RawInferredColumn: buf.WriteString(typedColumnConfig.Name) id := hash.Bytes(buf.Bytes()) - rawColumn = &context.RawValueColumn{ + rawColumn = &context.RawInferredColumn{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, ResourceType: resource.RawColumnType, }, }, - RawValueColumn: typedColumnConfig, + RawInferredColumn: typedColumnConfig, } default: return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index 16153cd4eb..d19aac4fbd 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -67,7 +67,7 @@ func loadUserTransformers( ResourceFields: userconfig.ResourceFields{ Name: implHash, }, - OutputType: userconfig.ValueColumnType, + OutputType: userconfig.InferredColumnType, Path: *transColConfig.TransformerPath, } transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages) diff --git a/pkg/workloads/consts.py b/pkg/workloads/consts.py index e0d8875ff6..ed330d68be 100644 --- a/pkg/workloads/consts.py +++ b/pkg/workloads/consts.py @@ -20,7 +20,7 @@ COLUMN_TYPE_INT_LIST = "INT_LIST_COLUMN" COLUMN_TYPE_FLOAT_LIST = "FLOAT_LIST_COLUMN" COLUMN_TYPE_STRING_LIST = "STRING_LIST_COLUMN" -COLUMN_TYPE_VALUE = "VALUE_COLUMN" +COLUMN_TYPE_INFERRED = "INFERRED_COLUMN" COLUMN_LIST_TYPES = [COLUMN_TYPE_INT_LIST, COLUMN_TYPE_FLOAT_LIST, COLUMN_TYPE_STRING_LIST] @@ -31,7 +31,7 @@ COLUMN_TYPE_INT_LIST, COLUMN_TYPE_FLOAT_LIST, COLUMN_TYPE_STRING_LIST, - COLUMN_TYPE_VALUE, + COLUMN_TYPE_INFERRED, ] VALUE_TYPE_INT = "INT" diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index a7ab5cfc60..4ff01c2058 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -489,7 +489,7 @@ def get_metadata(self, resource_id, use_cache=True): def get_inferred_column_type(self, column_name): column = self.columns[column_name] column_type 
= self.columns[column_name].get("type", "unknown") - if column_type == consts.COLUMN_TYPE_VALUE: + if column_type == consts.COLUMN_TYPE_INFERRED: column_type = self.get_metadata(column["id"])["type"] self.columns[column_name]["type"] = column_type diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index b10a8eac5a..94a319ac7b 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -229,7 +229,7 @@ def ingest(ctx, spark): actual_type = StringType() column_type = raw_column["type"] - if column_type == consts.COLUMN_TYPE_VALUE: + if column_type == consts.COLUMN_TYPE_INFERRED: sample_df = df.select(raw_column_name).limit(1).collect() sample = sample_df[0][raw_column_name] inferred_type = infer_type(sample) @@ -530,7 +530,7 @@ def validate_transformer(column_name, test_df, ctx, spark): if hasattr(trans_impl, "transform_python"): try: - if transformer["output_type"] == consts.COLUMN_TYPE_VALUE: + if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: sample_df = test_df.collect() sample = sample_df[0] inputs = ctx.create_column_inputs_map(sample, column_name) @@ -612,7 +612,7 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ) - if transformer["output_type"] == consts.COLUMN_TYPE_VALUE: + if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) @@ -645,7 +645,7 @@ def validate_transformer(column_name, test_df, ctx, spark): if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): if ( - transformer["output_type"] == consts.COLUMN_TYPE_VALUE + transformer["output_type"] == consts.COLUMN_TYPE_INFERRED and inferred_spark_type != inferred_python_type ): raise UserRuntimeException( From 61a4c33a186557950db267679a1dc34c887d28d2 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 18:35:37 -0400 Subject: [PATCH 12/27] value -> infer --- pkg/workloads/spark_job/spark_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 94a319ac7b..6a5b5006e6 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -283,7 +283,7 @@ def read_csv(ctx, spark): casted_cols = [] for column_name in ctx.raw_columns: column_type = ctx.raw_columns[column_name]["type"] - if column_type == consts.COLUMN_TYPE_VALUE: + if column_type == consts.COLUMN_TYPE_INFERRED: casted_cols.append(F.col(column_name).alias(column_name)) else: casted_cols.append( From 1c9d02c92601cd8cce80fe1d9c9ca9972139afc3 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 19:23:16 -0400 Subject: [PATCH 13/27] address spark comments --- pkg/workloads/spark_job/spark_util.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 6a5b5006e6..14554d0d4f 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -220,6 +220,7 @@ def ingest(ctx, spark): df = read_parquet(ctx, spark) input_type_map = {f.name: f.dataType for f in df.schema} + inferred_col_type_map = {c.name: c.dataType for c in df.schema} for raw_column_name in ctx.raw_columns.keys(): raw_column = ctx.raw_columns[raw_column_name] @@ -230,11 +231,8 @@ def ingest(ctx, spark): column_type = 
raw_column["type"] if column_type == consts.COLUMN_TYPE_INFERRED: - sample_df = df.select(raw_column_name).limit(1).collect() - sample = sample_df[0][raw_column_name] - inferred_type = infer_type(sample) - ctx.write_metadata(raw_column["id"], {"type": inferred_type}) - column_type = inferred_type + column_type = SPARK_TYPE_TO_CORTEX_TYPE[inferred_col_type_map[raw_column_name]] + ctx.write_metadata(raw_column["id"], {"type": column_type}) expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[column_type] if actual_type not in expected_types: @@ -284,15 +282,14 @@ def read_csv(ctx, spark): for column_name in ctx.raw_columns: column_type = ctx.raw_columns[column_name]["type"] if column_type == consts.COLUMN_TYPE_INFERRED: - casted_cols.append(F.col(column_name).alias(column_name)) + casted_cols.append(F.col(column_name)) else: casted_cols.append( F.col(column_name).cast(CORTEX_TYPE_TO_SPARK_TYPE[column_type]).alias(column_name) ) - df = df.select(*casted_cols) + return df.select(*casted_cols) - return df.select(*ctx.raw_columns.keys()) def read_parquet(ctx, spark): From e000c6039030c3c3bf00ec77c477d149bd6a34a1 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 19:24:29 -0400 Subject: [PATCH 14/27] format --- pkg/workloads/spark_job/spark_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 14554d0d4f..512d778442 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -291,7 +291,6 @@ def read_csv(ctx, spark): return df.select(*casted_cols) - def read_parquet(ctx, spark): parquet_config = ctx.environment["data"] df = spark.read.parquet(parquet_config["path"]) From 1aa7dc60a2e806f23eb52b32d10c23ba44c6821f Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 29 May 2019 19:53:38 -0400 Subject: [PATCH 15/27] use VALUE for jsonp --- pkg/operator/api/userconfig/column_type.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/operator/api/userconfig/column_type.go b/pkg/operator/api/userconfig/column_type.go index b6f6990fe7..9195df82be 100644 --- a/pkg/operator/api/userconfig/column_type.go +++ b/pkg/operator/api/userconfig/column_type.go @@ -53,7 +53,7 @@ var columnJSONPlaceholders = []string{ "[INT]", "[FLOAT]", "[\"STRING\"]", - "INFER", + "VALUE", } func ColumnTypeFromString(s string) ColumnType { From fdfc5cb813717c7b1bea1ac44681f28235c92565 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 30 May 2019 16:57:35 -0400 Subject: [PATCH 16/27] cast types and validate --- pkg/workloads/spark_job/spark_util.py | 70 ++++++++---------- .../spark_job/test/unit/spark_util_test.py | 73 ++++++++++++++++++- 2 files changed, 104 insertions(+), 39 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 512d778442..c4642824cd 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -48,6 +48,15 @@ consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], } +CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES = { + consts.COLUMN_TYPE_INT: [IntegerType(), LongType()], + consts.COLUMN_TYPE_INT_LIST: [ArrayType(IntegerType(), True), ArrayType(LongType(), True)], + consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType()], + consts.COLUMN_TYPE_FLOAT_LIST: [ArrayType(FloatType(), True), ArrayType(DoubleType(), True)], + consts.COLUMN_TYPE_STRING: [StringType(), IntegerType(), LongType(), FloatType(), DoubleType()], + consts.COLUMN_TYPE_STRING_LIST: 
[ArrayType(StringType(), True)], +} + PYTHON_TYPE_TO_CORTEX_TYPE = { int: consts.COLUMN_TYPE_INT, float: consts.COLUMN_TYPE_FLOAT, @@ -221,35 +230,33 @@ def ingest(ctx, spark): input_type_map = {f.name: f.dataType for f in df.schema} inferred_col_type_map = {c.name: c.dataType for c in df.schema} - - for raw_column_name in ctx.raw_columns.keys(): + for raw_column_name in ctx.raw_columns: raw_column = ctx.raw_columns[raw_column_name] - actual_type = input_type_map[raw_column_name] - - if actual_type not in SPARK_TYPE_TO_CORTEX_TYPE: - actual_type = StringType() - - column_type = raw_column["type"] - if column_type == consts.COLUMN_TYPE_INFERRED: + expected_cortex_type = raw_column["type"] + if expected_cortex_type == consts.COLUMN_TYPE_INFERRED: column_type = SPARK_TYPE_TO_CORTEX_TYPE[inferred_col_type_map[raw_column_name]] ctx.write_metadata(raw_column["id"], {"type": column_type}) + else: + actual_spark_type = input_type_map[raw_column_name] + if actual_spark_type not in SPARK_TYPE_TO_CORTEX_TYPE: + actual_spark_type = StringType() - expected_types = CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[column_type] - if actual_type not in expected_types: - logger.error("found schema:") - log_df_schema(df, logger.error) + expected_types = CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES[expected_cortex_type] + if actual_spark_type not in expected_types: + logger.error("found schema:") + log_df_schema(df, logger.error) - raise UserException( - "raw column " + raw_column_name, - "type mismatch", - "expected {} but found {}".format( - " or ".join(str(x) for x in expected_types), actual_type - ), - ) - target_type = CORTEX_TYPE_TO_SPARK_TYPE[column_type] + raise UserException( + "raw column " + raw_column_name, + "type mismatch", + "expected {} but found {}".format( + " or ".join(str(x) for x in expected_types), actual_spark_type + ), + ) - if target_type != actual_type: - df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_type)) + if actual_spark_type != expected_cortex_type: + target_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] + df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_spark_type)) return df.select(*sorted(df.columns)) @@ -263,7 +270,7 @@ def read_csv(ctx, spark): if val is not None } - df = spark.read.csv(data_config["path"], inferSchema=True, **csv_config) + df = spark.read.csv(data_config["path"], inferSchema=True, mode="FAILFAST", **csv_config) renamed_cols = [F.col(c).alias(data_config["schema"][idx]) for idx, c in enumerate(df.columns)] df = df.select(*renamed_cols) @@ -276,20 +283,7 @@ def read_csv(ctx, spark): + " but got " + str(set(df.columns)) ) - - # for columns with types, cast it - casted_cols = [] - for column_name in ctx.raw_columns: - column_type = ctx.raw_columns[column_name]["type"] - if column_type == consts.COLUMN_TYPE_INFERRED: - casted_cols.append(F.col(column_name)) - else: - casted_cols.append( - F.col(column_name).cast(CORTEX_TYPE_TO_SPARK_TYPE[column_type]).alias(column_name) - ) - - return df.select(*casted_cols) - + return df def read_parquet(ctx, spark): parquet_config = ctx.environment["data"] diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index a72f0b53b2..e57df0319a 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -47,7 +47,7 @@ def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, 
get_context): - csv_str = "\n".join(["a,0.1,", "b,1,1", "c,1.1,4"]) + csv_str = "\n".join(["a,0.1,", "b,b,1", "c,1.1,4"]) path_to_file = write_csv_file(csv_str) @@ -64,6 +64,77 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): with pytest.raises(UserException): spark_util.ingest(get_context(ctx_obj), spark).collect() +def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): + csv_str = "\n".join(["a,0.1,", "b,0.1,1", "c,1.1,4"]) + + path_to_file = write_csv_file(csv_str) + + ctx_obj["environment"] = { + "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} + } + + df = spark_util.ingest(get_context(ctx_obj), spark) + assert df.count() == 3 + inferred_col_type_map = {c.name: c.dataType for c in df.schema} + assert inferred_col_type_map["a_str"] == StringType() + assert inferred_col_type_map["b_float"] == DoubleType() + assert inferred_col_type_map["c_long"] == IntegerType() + + +def test_read_csv_infer_and_cast_type(spark, write_csv_file, ctx_obj, get_context): + csv_str = "\n".join(["1,4,4.5", "1,3,1.2", "1,5,4.7"]) + + path_to_file = write_csv_file(csv_str) + + ctx_obj["environment"] = { + "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_int", "c_float"]} + } + + ctx_obj["raw_columns"] = { + "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, + "c_float": {"name": "c_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + } + + df = spark_util.ingest(get_context(ctx_obj), spark) + assert df.count() == 3 + inferred_col_type_map = {c.name: c.dataType for c in df.schema} + assert inferred_col_type_map["a_str"] == StringType() + assert inferred_col_type_map["b_int"] == IntegerType() + assert inferred_col_type_map["c_float"] == FloatType() + + +def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): + csv_str = "\n".join(["a,0.1,", "a,0.1,1", "a,1.1,4"]) + + path_to_file = write_csv_file(csv_str) + + ctx_obj["environment"] = { + "data": {"type": "csv", "path": path_to_file, "schema": ["a_int", "b_float", "c_long"]} + } + + ctx_obj["raw_columns"] = { + "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"}, + } + + with pytest.raises(UserException): + spark_util.ingest(get_context(ctx_obj), spark).collect() + + csv_str = "\n".join(["a,1,", "a,1,1", "a,1,4"]) + path_to_file = write_csv_file(csv_str) + + ctx_obj["environment"] = { + "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} + } + + ctx_obj["raw_columns"] = { + "b_float": {"name": "a_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + } + + with pytest.raises(UserException): + spark_util.ingest(get_context(ctx_obj), spark).collect() + + + def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,1,", "b,1,"]) From 659c89772f64bb3cb10b96b17628c796a9abab06 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 30 May 2019 16:59:38 -0400 Subject: [PATCH 17/27] lint --- pkg/workloads/spark_job/spark_util.py | 1 + pkg/workloads/spark_job/test/unit/spark_util_test.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index c4642824cd..d0241aeb19 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -285,6 +285,7 @@ def read_csv(ctx, spark): ) return df + def read_parquet(ctx, spark): parquet_config = ctx.environment["data"] df = 
spark.read.parquet(parquet_config["path"]) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index e57df0319a..02aec117c8 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -64,6 +64,7 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): with pytest.raises(UserException): spark_util.ingest(get_context(ctx_obj), spark).collect() + def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,0.1,", "b,0.1,1", "c,1.1,4"]) @@ -113,7 +114,7 @@ def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): } ctx_obj["raw_columns"] = { - "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"}, + "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"} } with pytest.raises(UserException): @@ -127,15 +128,13 @@ def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): } ctx_obj["raw_columns"] = { - "b_float": {"name": "a_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + "b_float": {"name": "a_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"} } with pytest.raises(UserException): spark_util.ingest(get_context(ctx_obj), spark).collect() - - def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,1,", "b,1,"]) From 90571775b04359a81647aae85ed5655785b7a11c Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 30 May 2019 17:37:42 -0400 Subject: [PATCH 18/27] add more casting test and table it --- .../spark_job/test/unit/spark_util_test.py | 125 ++++++++++-------- 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 02aec117c8..df07850867 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -66,73 +66,86 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): - csv_str = "\n".join(["a,0.1,", "b,0.1,1", "c,1.1,4"]) - - path_to_file = write_csv_file(csv_str) - - ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} - } - - df = spark_util.ingest(get_context(ctx_obj), spark) - assert df.count() == 3 - inferred_col_type_map = {c.name: c.dataType for c in df.schema} - assert inferred_col_type_map["a_str"] == StringType() - assert inferred_col_type_map["b_float"] == DoubleType() - assert inferred_col_type_map["c_long"] == IntegerType() - - -def test_read_csv_infer_and_cast_type(spark, write_csv_file, ctx_obj, get_context): - csv_str = "\n".join(["1,4,4.5", "1,3,1.2", "1,5,4.7"]) + test_cases = [ + { + "csv": ["a,0.1,", "b,0.1,1", "c,1.1,4"], + "schema": ["a_str", "b_float", "c_long"], + "raw_columns": {}, + "expected_types": { + "a_str": StringType(), + "b_float": DoubleType(), + "c_long": IntegerType(), + }, + }, + { + "csv": ["1,4,4.5", "1,3,1.2", "1,5,4.7"], + "schema": ["a_str", "b_int", "c_float"], + "raw_columns": { + "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, + "c_float": {"name": "c_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + }, + "expected_types": { + "a_str": StringType(), + "b_int": IntegerType(), + "c_float": FloatType(), + }, + }, 
+ ] - path_to_file = write_csv_file(csv_str) + for test in test_cases: + csv_str = "\n".join(test["csv"]) + path_to_file = write_csv_file(csv_str) - ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_int", "c_float"]} - } + ctx_obj["environment"] = { + "data": {"type": "csv", "path": path_to_file, "schema": test["schema"]} + } - ctx_obj["raw_columns"] = { - "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, - "c_float": {"name": "c_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, - } + ctx_obj["raw_columns"] = test["raw_columns"] - df = spark_util.ingest(get_context(ctx_obj), spark) - assert df.count() == 3 - inferred_col_type_map = {c.name: c.dataType for c in df.schema} - assert inferred_col_type_map["a_str"] == StringType() - assert inferred_col_type_map["b_int"] == IntegerType() - assert inferred_col_type_map["c_float"] == FloatType() + df = spark_util.ingest(get_context(ctx_obj), spark) + assert df.count() == len(test["expected_types"]) + inferred_col_type_map = {c.name: c.dataType for c in df.schema} + for column_name in test["expected_types"]: + assert inferred_col_type_map[column_name] == test["expected_types"][column_name] def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): - csv_str = "\n".join(["a,0.1,", "a,0.1,1", "a,1.1,4"]) - - path_to_file = write_csv_file(csv_str) - - ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_int", "b_float", "c_long"]} - } - - ctx_obj["raw_columns"] = { - "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"} - } - - with pytest.raises(UserException): - spark_util.ingest(get_context(ctx_obj), spark).collect() + test_cases = [ + { + "csv": ["a,0.1,", "a,0.1,1", "a,1.1,4"], + "schema": ["a_int", "b_float", "c_long"], + "raw_columns": { + "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"} + }, + }, + { + "csv": ["a,1,", "a,1,1", "a,1,4"], + "schema": ["a_int", "b_float", "c_long"], + "raw_columns": { + "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"} + }, + }, + { + "csv": ["a,1.1,", "a,1.1,1", "a,1.1,4"], + "schema": ["a_str", "b_int", "c_long"], + "raw_columns": { + "b_int": {"name": "b_int", "type": "INT_COLUMN", "required": True, "id": "-"} + }, + }, + ] - csv_str = "\n".join(["a,1,", "a,1,1", "a,1,4"]) - path_to_file = write_csv_file(csv_str) + for test in test_cases: + csv_str = "\n".join(test["csv"]) + path_to_file = write_csv_file(csv_str) - ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} - } + ctx_obj["environment"] = { + "data": {"type": "csv", "path": path_to_file, "schema": test["schema"]} + } - ctx_obj["raw_columns"] = { - "b_float": {"name": "a_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"} - } + ctx_obj["raw_columns"] = test["raw_columns"] - with pytest.raises(UserException): - spark_util.ingest(get_context(ctx_obj), spark).collect() + with pytest.raises(UserException): + spark_util.ingest(get_context(ctx_obj), spark).collect() def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): From 72b520c62d74b76d0f410c6cf52ea2143bb1795a Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 31 May 2019 14:51:31 -0400 Subject: [PATCH 19/27] add more tests and refactor --- pkg/operator/api/userconfig/config.go | 1 + pkg/workloads/spark_job/spark_job.py | 7 + pkg/workloads/spark_job/spark_util.py | 25 +-- 
.../spark_job/test/unit/spark_util_test.py | 192 ++++++++++++++++-- 4 files changed, 197 insertions(+), 28 deletions(-) diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index c69acb40e8..74f8fbee6c 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -455,6 +455,7 @@ func New(configs map[string][]byte, envName string) (*Config, error) { cr.Struct(inferredRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) config.RawColumns = append(config.RawColumns, inferredRawColumn) } + rawColumnNames = config.RawColumns.Names() } if err := config.Validate(envName); err != nil { diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 253bb2b093..8b60d8dca8 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -25,6 +25,7 @@ from lib.exceptions import UserException, CortexException, UserRuntimeException import spark_util import pyspark.sql.functions as F +import consts logger = get_logger() @@ -150,6 +151,12 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): logger.info("Ingesting {} data from {}".format(ctx.app["name"], data_config["path"])) ingest_df = spark_util.ingest(ctx, spark) + input_type_map = {f.name: f.dataType for f in ingest_df.schema} + for raw_column_name in ctx.raw_columns: + if ctx.raw_columns[raw_column_name]["type"] == consts.COLUMN_TYPE_INFERRED: + column_type = spark_util.SPARK_TYPE_TO_CORTEX_TYPE[input_type_map[raw_column_name]] + ctx.write_metadata(ctx.raw_columns[raw_column_name]["id"], {"type": column_type}) + full_dataset_size = ingest_df.count() if data_config.get("drop_null"): diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index d0241aeb19..fcd564342e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -51,7 +51,7 @@ CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES = { consts.COLUMN_TYPE_INT: [IntegerType(), LongType()], consts.COLUMN_TYPE_INT_LIST: [ArrayType(IntegerType(), True), ArrayType(LongType(), True)], - consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType()], + consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType(), IntegerType(), LongType()], consts.COLUMN_TYPE_FLOAT_LIST: [ArrayType(FloatType(), True), ArrayType(DoubleType(), True)], consts.COLUMN_TYPE_STRING: [StringType(), IntegerType(), LongType(), FloatType(), DoubleType()], consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], @@ -229,17 +229,15 @@ def ingest(ctx, spark): df = read_parquet(ctx, spark) input_type_map = {f.name: f.dataType for f in df.schema} - inferred_col_type_map = {c.name: c.dataType for c in df.schema} + logger.info(ctx.raw_columns) for raw_column_name in ctx.raw_columns: raw_column = ctx.raw_columns[raw_column_name] expected_cortex_type = raw_column["type"] - if expected_cortex_type == consts.COLUMN_TYPE_INFERRED: - column_type = SPARK_TYPE_TO_CORTEX_TYPE[inferred_col_type_map[raw_column_name]] - ctx.write_metadata(raw_column["id"], {"type": column_type}) - else: - actual_spark_type = input_type_map[raw_column_name] - if actual_spark_type not in SPARK_TYPE_TO_CORTEX_TYPE: - actual_spark_type = StringType() + actual_spark_type = input_type_map[raw_column_name] + + if actual_spark_type in SPARK_TYPE_TO_CORTEX_TYPE: + if expected_cortex_type == consts.COLUMN_TYPE_INFERRED: + expected_cortex_type = SPARK_TYPE_TO_CORTEX_TYPE[input_type_map[raw_column_name]] expected_types = 
CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES[expected_cortex_type] if actual_spark_type not in expected_types: @@ -253,10 +251,13 @@ def ingest(ctx, spark): " or ".join(str(x) for x in expected_types), actual_spark_type ), ) - - if actual_spark_type != expected_cortex_type: - target_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] + target_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] + if actual_spark_type != target_spark_type: + target_spark_type = target_spark_type df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_spark_type)) + else: # for unexpected types, just cast to string + df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(StringType())) + return df.select(*sorted(df.columns)) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index df07850867..eac5455737 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -23,7 +23,7 @@ import pyspark.sql.functions as F from mock import MagicMock, call from py4j.protocol import Py4JJavaError - +from datetime import datetime pytestmark = pytest.mark.usefixtures("spark") @@ -70,11 +70,15 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): { "csv": ["a,0.1,", "b,0.1,1", "c,1.1,4"], "schema": ["a_str", "b_float", "c_long"], - "raw_columns": {}, + "raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "c_long": {"name": "c_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + }, "expected_types": { "a_str": StringType(), - "b_float": DoubleType(), - "c_long": IntegerType(), + "b_float": FloatType(), + "c_long": LongType(), }, }, { @@ -82,14 +86,43 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): "schema": ["a_str", "b_int", "c_float"], "raw_columns": { "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, - "c_float": {"name": "c_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "c_float": {"name": "c_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, }, "expected_types": { "a_str": StringType(), - "b_int": IntegerType(), + "b_int": LongType(), "c_float": FloatType(), }, }, + { + "csv": ["1,4,2017-09-16", "1,3,2017-09-16", "1,5,2017-09-16"], + "schema": ["a_str", "b_int", "c_str"], + "raw_columns": { + "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, + "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + }, + "expected_types": { + "a_str": StringType(), + "b_int": LongType(), + "c_str": StringType(), + }, + }, + { + "csv": ["1,4,2017-09-16", "1,3,2017-09-16", "1,5,2017-09-16"], + "schema": ["a_float", "b_int", "c_str"], + "raw_columns": { + "a_float": {"name": "a_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, + "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + }, + "expected_types": { + "a_float": FloatType(), + "b_int": LongType(), + "c_str": StringType(), + }, + }, ] for test in test_cases: @@ -115,21 +148,18 @@ def 
test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): "csv": ["a,0.1,", "a,0.1,1", "a,1.1,4"], "schema": ["a_int", "b_float", "c_long"], "raw_columns": { - "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"} - }, - }, - { - "csv": ["a,1,", "a,1,1", "a,1,4"], - "schema": ["a_int", "b_float", "c_long"], - "raw_columns": { - "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"} + "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"}, + "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "c_long": {"name": "c_long", "type": "INFERRED_COLUMN", "required": True, "id": "-"} }, }, { "csv": ["a,1.1,", "a,1.1,1", "a,1.1,4"], - "schema": ["a_str", "b_int", "c_long"], + "schema": ["a_str", "b_int", "c_int"], "raw_columns": { - "b_int": {"name": "b_int", "type": "INT_COLUMN", "required": True, "id": "-"} + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, + "b_int": {"name": "b_int", "type": "INT_COLUMN", "required": True, "id": "-"}, + "c_int": {"name": "c_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, }, }, ] @@ -422,6 +452,136 @@ def test_ingest_parquet_valid(spark, write_parquet_file, ctx_obj, get_context): ] + +def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_context): + tests = [ + { + "data": [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", DoubleType()), + StructField("c_long", IntegerType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, + "c_long": {"name": "c_long", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + }, + "expected_types": [ + ("a_str", StringType()), + ("b_float", FloatType()), + ("c_long", LongType()), + ] + }, + { + "data": [("1", 0.1, None), ("1", 1.0, None), ("1", 1.1, 4)], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", DoubleType()), + StructField("c_long", IntegerType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, + "c_long": {"name": "c_long", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + }, + "expected_types": [ + ("a_str", StringType()), + ("b_float", FloatType()), + ("c_long", LongType()), + ] + }, + { + "data": [("1", 0.1, datetime.now()), ("1", 1.0, datetime.now()), ("1", 1.1, datetime.now())], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", DoubleType()), + StructField("c_str", TimestampType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, + ], + 
"raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, + "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + }, + "expected_types": [ + ("a_str", StringType()), + ("b_float", FloatType()), + ("c_str", StringType()), + ] + }, + { + "data": [(1, 0.1, datetime.now()), (1, 1.0, datetime.now()), (1, 1.1, datetime.now())], + "schema": StructType( + [ + StructField("a_long", IntegerType()), + StructField("b_float", DoubleType()), + StructField("c_str", TimestampType()), + ] + ), + "env": [ + {"parquet_column_name": "a_long", "raw_column_name": "a_long"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, + ], + "raw_columns": { + "a_long": {"name": "a_long", "type": "FLOAT_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, + "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + }, + "expected_types": [ + ("a_long", FloatType()), + ("b_float", FloatType()), + ("c_str", StringType()), + ] + } + ] + + for test in tests: + data = test["data"] + + schema = test["schema"] + + path_to_file = write_parquet_file(spark, data, schema) + + ctx_obj["environment"] = { + "data": { + "type": "parquet", + "path": path_to_file, + "schema": test["env"], + } + } + + ctx_obj["raw_columns"] = test["raw_columns"] + + + df = spark_util.ingest(get_context(ctx_obj), spark) + + assert df.count() == 3 + assert sorted([(s.name, s.dataType) for s in df.schema], key=lambda x: x[0]) == test["expected_types"] + + + def test_ingest_parquet_extra_cols(spark, write_parquet_file, ctx_obj, get_context): data = [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)] From d39d31e1c94fe14b3ed3402380816d8a01b74cf8 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 31 May 2019 16:26:31 -0400 Subject: [PATCH 20/27] lint and format --- pkg/workloads/spark_job/spark_job.py | 8 +- pkg/workloads/spark_job/spark_util.py | 3 +- .../spark_job/test/unit/spark_util_test.py | 141 +++++++++++------- 3 files changed, 97 insertions(+), 55 deletions(-) diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 8b60d8dca8..5383f17db1 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -154,8 +154,12 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): input_type_map = {f.name: f.dataType for f in ingest_df.schema} for raw_column_name in ctx.raw_columns: if ctx.raw_columns[raw_column_name]["type"] == consts.COLUMN_TYPE_INFERRED: - column_type = spark_util.SPARK_TYPE_TO_CORTEX_TYPE[input_type_map[raw_column_name]] - ctx.write_metadata(ctx.raw_columns[raw_column_name]["id"], {"type": column_type}) + column_type = spark_util.SPARK_TYPE_TO_CORTEX_TYPE[ + input_type_map[raw_column_name] + ] + ctx.write_metadata( + ctx.raw_columns[raw_column_name]["id"], {"type": column_type} + ) full_dataset_size = ingest_df.count() diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index fcd564342e..5e571e4d16 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -255,10 +255,9 @@ def ingest(ctx, spark): if actual_spark_type != target_spark_type: target_spark_type = target_spark_type df = df.withColumn(raw_column_name, 
F.col(raw_column_name).cast(target_spark_type)) - else: # for unexpected types, just cast to string + else: # for unexpected types, just cast to string df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(StringType())) - return df.select(*sorted(df.columns)) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index eac5455737..177df969be 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -72,14 +72,20 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): "schema": ["a_str", "b_float", "c_long"], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, - "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, - "c_long": {"name": "c_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, - }, - "expected_types": { - "a_str": StringType(), - "b_float": FloatType(), - "c_long": LongType(), + "b_float": { + "name": "b_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "-", + }, + "c_long": { + "name": "c_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "-", + }, }, + "expected_types": {"a_str": StringType(), "b_float": FloatType(), "c_long": LongType()}, }, { "csv": ["1,4,4.5", "1,3,1.2", "1,5,4.7"], @@ -87,13 +93,14 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): "raw_columns": { "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, - "c_float": {"name": "c_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, - }, - "expected_types": { - "a_str": StringType(), - "b_int": LongType(), - "c_float": FloatType(), + "c_float": { + "name": "c_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "-", + }, }, + "expected_types": {"a_str": StringType(), "b_int": LongType(), "c_float": FloatType()}, }, { "csv": ["1,4,2017-09-16", "1,3,2017-09-16", "1,5,2017-09-16"], @@ -103,13 +110,9 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, }, - "expected_types": { - "a_str": StringType(), - "b_int": LongType(), - "c_str": StringType(), - }, + "expected_types": {"a_str": StringType(), "b_int": LongType(), "c_str": StringType()}, }, - { + { "csv": ["1,4,2017-09-16", "1,3,2017-09-16", "1,5,2017-09-16"], "schema": ["a_float", "b_int", "c_str"], "raw_columns": { @@ -117,11 +120,7 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, }, - "expected_types": { - "a_float": FloatType(), - "b_int": LongType(), - "c_str": StringType(), - }, + "expected_types": {"a_float": FloatType(), "b_int": LongType(), "c_str": StringType()}, }, ] @@ -149,8 +148,18 @@ def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): "schema": ["a_int", "b_float", "c_long"], "raw_columns": { "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"}, - "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, - "c_long": {"name": "c_long", 
"type": "INFERRED_COLUMN", "required": True, "id": "-"} + "b_float": { + "name": "b_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "-", + }, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": True, + "id": "-", + }, }, }, { @@ -452,12 +461,11 @@ def test_ingest_parquet_valid(spark, write_parquet_file, ctx_obj, get_context): ] - def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_context): tests = [ { "data": [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)], - "schema": StructType( + "schema": StructType( [ StructField("a_str", StringType()), StructField("b_float", DoubleType()), @@ -471,18 +479,28 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, - "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, - "c_long": {"name": "c_long", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + "b_float": { + "name": "b_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "2", + }, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": False, + "id": "3", + }, }, "expected_types": [ ("a_str", StringType()), ("b_float", FloatType()), ("c_long", LongType()), - ] + ], }, { "data": [("1", 0.1, None), ("1", 1.0, None), ("1", 1.1, 4)], - "schema": StructType( + "schema": StructType( [ StructField("a_str", StringType()), StructField("b_float", DoubleType()), @@ -496,18 +514,32 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, - "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, - "c_long": {"name": "c_long", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + "b_float": { + "name": "b_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "2", + }, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": False, + "id": "3", + }, }, "expected_types": [ ("a_str", StringType()), ("b_float", FloatType()), ("c_long", LongType()), - ] + ], }, { - "data": [("1", 0.1, datetime.now()), ("1", 1.0, datetime.now()), ("1", 1.1, datetime.now())], - "schema": StructType( + "data": [ + ("1", 0.1, datetime.now()), + ("1", 1.0, datetime.now()), + ("1", 1.1, datetime.now()), + ], + "schema": StructType( [ StructField("a_str", StringType()), StructField("b_float", DoubleType()), @@ -521,18 +553,23 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, - "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, + "b_float": { + "name": "b_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "2", + }, "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, }, "expected_types": [ ("a_str", StringType()), ("b_float", FloatType()), ("c_str", StringType()), - ] + ], }, { "data": [(1, 0.1, datetime.now()), (1, 1.0, datetime.now()), (1, 1.1, datetime.now())], - "schema": StructType( + "schema": StructType( [ StructField("a_long", IntegerType()), StructField("b_float", DoubleType()), @@ -546,15 +583,20 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ], "raw_columns": { "a_long": {"name": "a_long", "type": 
"FLOAT_COLUMN", "required": True, "id": "1"}, - "b_float": {"name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2"}, + "b_float": { + "name": "b_float", + "type": "INFERRED_COLUMN", + "required": True, + "id": "2", + }, "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, }, "expected_types": [ ("a_long", FloatType()), ("b_float", FloatType()), ("c_str", StringType()), - ] - } + ], + }, ] for test in tests: @@ -565,21 +607,18 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont path_to_file = write_parquet_file(spark, data, schema) ctx_obj["environment"] = { - "data": { - "type": "parquet", - "path": path_to_file, - "schema": test["env"], - } + "data": {"type": "parquet", "path": path_to_file, "schema": test["env"]} } ctx_obj["raw_columns"] = test["raw_columns"] - df = spark_util.ingest(get_context(ctx_obj), spark) assert df.count() == 3 - assert sorted([(s.name, s.dataType) for s in df.schema], key=lambda x: x[0]) == test["expected_types"] - + assert ( + sorted([(s.name, s.dataType) for s in df.schema], key=lambda x: x[0]) + == test["expected_types"] + ) def test_ingest_parquet_extra_cols(spark, write_parquet_file, ctx_obj, get_context): From edd9e882cf253d5d5f3ba76dddcb3b10cf9793c8 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 3 Jun 2019 12:47:58 -0400 Subject: [PATCH 21/27] address comments --- pkg/operator/api/userconfig/config.go | 4 +- pkg/operator/api/userconfig/environments.go | 13 ++ pkg/operator/api/userconfig/errors.go | 20 ++- pkg/workloads/spark_job/spark_util.py | 120 ++++++++++++------ .../spark_job/test/unit/spark_util_test.py | 77 ++++++++--- 5 files changed, 170 insertions(+), 64 deletions(-) diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index 74f8fbee6c..e731697ac6 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -440,10 +440,9 @@ func New(configs map[string][]byte, envName string) (*Config, error) { } } - rawColumnNames := config.RawColumns.Names() for _, env := range config.Environments { ingestedColumnNames := env.Data.GetIngestedColumns() - missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames) + missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, config.RawColumns.Names()) for _, inferredColumnName := range missingColumnNames { inferredRawColumn := &RawInferredColumn{ ResourceFields: ResourceFields{ @@ -455,7 +454,6 @@ func New(configs map[string][]byte, envName string) (*Config, error) { cr.Struct(inferredRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation) config.RawColumns = append(config.RawColumns, inferredRawColumn) } - rawColumnNames = config.RawColumns.Names() } if err := config.Validate(envName); err != nil { diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go index 940765f080..9ae93b2cfa 100644 --- a/pkg/operator/api/userconfig/environments.go +++ b/pkg/operator/api/userconfig/environments.go @@ -21,6 +21,7 @@ import ( cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/pointer" + "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/lib/slices" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -337,6 +338,18 @@ func (environments Environments) Validate() error { return ErrorDuplicateResourceName(dups...) 
}
 
+	ingestedColumns := environments[0].Data.GetIngestedColumns()
+	for _, env := range environments[1:] {
+		difference := strset.Difference(
+			strset.New(ingestedColumns...),
+			strset.New(env.Data.GetIngestedColumns()...),
+		)
+
+		if len(difference) > 0 {
+			return ErrorEnvSchemaMismatch(environments[0], env)
+		}
+	}
+
 	return nil
 }
 
diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go
index 142f311733..32609239b4 100644
--- a/pkg/operator/api/userconfig/errors.go
+++ b/pkg/operator/api/userconfig/errors.go
@@ -59,6 +59,7 @@ const (
 	ErrRegressionTargetType
 	ErrClassificationTargetType
 	ErrSpecifyOnlyOneMissing
+	ErrEnvSchemaMismatch
 )
 
 var errorKinds = []string{
@@ -92,9 +93,10 @@ var errorKinds = []string{
 	"err_regression_target_type",
 	"err_classification_target_type",
 	"err_specify_only_one_missing",
+	"err_env_schema_mismatch",
 }
 
-var _ = [1]int{}[int(ErrSpecifyOnlyOneMissing)-(len(errorKinds)-1)] // Ensure list length matches
+var _ = [1]int{}[int(ErrEnvSchemaMismatch)-(len(errorKinds)-1)] // Ensure list length matches
 
 func (t ErrorKind) String() string {
 	return errorKinds[t]
@@ -397,3 +399,19 @@ func ErrorSpecifyOnlyOneMissing(vals ...string) error {
 		message: message,
 	}
 }
+
+func ErrorEnvSchemaMismatch(env1, env2 *Environment) error {
+	difference := strset.Difference(
+		strset.New(env1.Data.GetIngestedColumns()...),
+		strset.New(env2.Data.GetIngestedColumns()...),
+	)
+
+	return Error{
+		Kind: ErrEnvSchemaMismatch,
+		message: fmt.Sprintf("schemas diverged between environments, %s lists %s columns but %s does not",
+			env1.Name,
+			s.StrsAnd(difference.Slice()),
+			env2.Name,
+		),
+	}
+}
diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index 5e571e4d16..bc96659926 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -49,12 +49,34 @@
 }
 
 CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES = {
-    consts.COLUMN_TYPE_INT: [IntegerType(), LongType()],
-    consts.COLUMN_TYPE_INT_LIST: [ArrayType(IntegerType(), True), ArrayType(LongType(), True)],
-    consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType(), IntegerType(), LongType()],
-    consts.COLUMN_TYPE_FLOAT_LIST: [ArrayType(FloatType(), True), ArrayType(DoubleType(), True)],
-    consts.COLUMN_TYPE_STRING: [StringType(), IntegerType(), LongType(), FloatType(), DoubleType()],
-    consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)],
+    "csv": {
+        consts.COLUMN_TYPE_INT: [IntegerType(), LongType()],
+        consts.COLUMN_TYPE_INT_LIST: [ArrayType(IntegerType(), True), ArrayType(LongType(), True)],
+        consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType(), IntegerType(), LongType()],
+        consts.COLUMN_TYPE_FLOAT_LIST: [
+            ArrayType(FloatType(), True),
+            ArrayType(DoubleType(), True),
+        ],
+        consts.COLUMN_TYPE_STRING: [
+            StringType(),
+            IntegerType(),
+            LongType(),
+            FloatType(),
+            DoubleType(),
+        ],
+        consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)],
+    },
+    "parquet": {
+        consts.COLUMN_TYPE_INT: [IntegerType(), LongType()],
+        consts.COLUMN_TYPE_INT_LIST: [ArrayType(IntegerType(), True), ArrayType(LongType(), True)],
+        consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType()],
+        consts.COLUMN_TYPE_FLOAT_LIST: [
+            ArrayType(FloatType(), True),
+            ArrayType(DoubleType(), True),
+        ],
+        consts.COLUMN_TYPE_STRING: [StringType()],
+        consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)],
+    },
 }
 
 PYTHON_TYPE_TO_CORTEX_TYPE = {
@@ -223,40 +245,60 @@ def value_check_data(ctx, df, raw_columns=None):
 
 
 def ingest(ctx, spark):
-    if 
ctx.environment["data"]["type"] == "csv": + fileType = ctx.environment["data"]["type"] + if fileType == "csv": df = read_csv(ctx, spark) - elif ctx.environment["data"]["type"] == "parquet": + elif fileType == "parquet": df = read_parquet(ctx, spark) input_type_map = {f.name: f.dataType for f in df.schema} - logger.info(ctx.raw_columns) for raw_column_name in ctx.raw_columns: raw_column = ctx.raw_columns[raw_column_name] expected_cortex_type = raw_column["type"] actual_spark_type = input_type_map[raw_column_name] - if actual_spark_type in SPARK_TYPE_TO_CORTEX_TYPE: - if expected_cortex_type == consts.COLUMN_TYPE_INFERRED: - expected_cortex_type = SPARK_TYPE_TO_CORTEX_TYPE[input_type_map[raw_column_name]] - - expected_types = CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES[expected_cortex_type] - if actual_spark_type not in expected_types: - logger.error("found schema:") - log_df_schema(df, logger.error) - - raise UserException( - "raw column " + raw_column_name, - "type mismatch", - "expected {} but found {}".format( - " or ".join(str(x) for x in expected_types), actual_spark_type - ), - ) - target_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] - if actual_spark_type != target_spark_type: - target_spark_type = target_spark_type - df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(target_spark_type)) - else: # for unexpected types, just cast to string - df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(StringType())) + if expected_cortex_type == consts.COLUMN_TYPE_INFERRED: + if actual_spark_type not in SPARK_TYPE_TO_CORTEX_TYPE: + df = df.withColumn(raw_column_name, F.col(raw_column_name).cast(StringType())) + else: + actual_cortex_type = SPARK_TYPE_TO_CORTEX_TYPE[actual_spark_type] + expected_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[actual_cortex_type] + if actual_spark_type != expected_spark_type: + df = df.withColumn( + raw_column_name, F.col(raw_column_name).cast(expected_spark_type) + ) + continue + else: + expected_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] + if actual_spark_type in SPARK_TYPE_TO_CORTEX_TYPE: + expected_types = CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES[fileType][expected_cortex_type] + if actual_spark_type not in expected_types: + logger.error("found schema:") + log_df_schema(df, logger.error) + + raise UserException( + "raw column " + raw_column_name, + "type mismatch", + "expected {} but found {}".format( + " or ".join(str(x) for x in expected_types), actual_spark_type + ), + ) + if actual_spark_type != expected_spark_type: + df = df.withColumn( + raw_column_name, F.col(raw_column_name).cast(expected_spark_type) + ) + else: + try: + df = df.withColumn( + raw_column_name, F.col(raw_column_name).cast(expected_spark_type) + ) + except Exception as e: + raise UserException( + "tried casting " + raw_column_name, + "to from ingested type " + actual_spark_type, + "to expected type " + expected_spark_type, + "but got exception: " + e, + ) return df.select(*sorted(df.columns)) @@ -271,19 +313,13 @@ def read_csv(ctx, spark): } df = spark.read.csv(data_config["path"], inferSchema=True, mode="FAILFAST", **csv_config) - renamed_cols = [F.col(c).alias(data_config["schema"][idx]) for idx, c in enumerate(df.columns)] - df = df.select(*renamed_cols) - - if set(data_config["schema"]) != set(df.columns): - logger.error("found schema:") - log_df_schema(df, logger.error) + if len(data_config["schema"]) != len(df.columns): raise UserException( - "expected column(s) " - + str(set(data_config["schema"])) - + " but got " - + str(set(df.columns)) + 
"expected " + len(data_config["schema"]) + " column(s) but got " + len(df.columns) ) - return df + + renamed_cols = [F.col(c).alias(data_config["schema"][idx]) for idx, c in enumerate(df.columns)] + return df.select(*renamed_cols) def read_parquet(ctx, spark): diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 177df969be..1e6d9ed526 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -567,35 +567,81 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ("c_str", StringType()), ], }, + ] + + for test in tests: + data = test["data"] + + schema = test["schema"] + + path_to_file = write_parquet_file(spark, data, schema) + + ctx_obj["environment"] = { + "data": {"type": "parquet", "path": path_to_file, "schema": test["env"]} + } + + ctx_obj["raw_columns"] = test["raw_columns"] + + df = spark_util.ingest(get_context(ctx_obj), spark) + + assert df.count() == 3 + assert ( + sorted([(s.name, s.dataType) for s in df.schema], key=lambda x: x[0]) + == test["expected_types"] + ) + + +def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_context): + tests = [ { - "data": [(1, 0.1, datetime.now()), (1, 1.0, datetime.now()), (1, 1.1, datetime.now())], + "data": [("a", 0.1, None), ("b", 1.0, None), ("c", 1.1, 4)], "schema": StructType( [ - StructField("a_long", IntegerType()), + StructField("a_str", StringType()), StructField("b_float", DoubleType()), - StructField("c_str", TimestampType()), + StructField("c_long", IntegerType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "INT_COLUMN", "required": True, "id": "2"}, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": False, + "id": "3", + }, + }, + }, + { + "data": [("1", 0.1, "yolo"), ("1", 1.0, "yolo"), ("1", 1.1, "yolo")], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", DoubleType()), + StructField("c_str", StringType()), ] ), "env": [ - {"parquet_column_name": "a_long", "raw_column_name": "a_long"}, + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, ], "raw_columns": { - "a_long": {"name": "a_long", "type": "FLOAT_COLUMN", "required": True, "id": "1"}, + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, "b_float": { "name": "b_float", "type": "INFERRED_COLUMN", "required": True, "id": "2", }, - "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + "c_str": {"name": "c_str", "type": "INT_COLUMN", "required": False, "id": "3"}, }, - "expected_types": [ - ("a_long", FloatType()), - ("b_float", FloatType()), - ("c_str", StringType()), - ], }, ] @@ -612,13 +658,8 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ctx_obj["raw_columns"] = test["raw_columns"] - df = spark_util.ingest(get_context(ctx_obj), spark) - - assert df.count() == 3 - assert ( - sorted([(s.name, s.dataType) for s in df.schema], key=lambda x: x[0]) - == 
test["expected_types"] - ) + with pytest.raises(UserException) as exec_info: + spark_util.ingest(get_context(ctx_obj), spark).collect() def test_ingest_parquet_extra_cols(spark, write_parquet_file, ctx_obj, get_context): From e52753e7c4c753790687b2c38b2ba76b44586ac5 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 3 Jun 2019 13:37:19 -0400 Subject: [PATCH 22/27] format --- pkg/workloads/spark_job/spark_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index bc96659926..e865d8880e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -264,9 +264,9 @@ def ingest(ctx, spark): actual_cortex_type = SPARK_TYPE_TO_CORTEX_TYPE[actual_spark_type] expected_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[actual_cortex_type] if actual_spark_type != expected_spark_type: - df = df.withColumn( - raw_column_name, F.col(raw_column_name).cast(expected_spark_type) - ) + df = df.withColumn( + raw_column_name, F.col(raw_column_name).cast(expected_spark_type) + ) continue else: expected_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] From aaeff4f2c6343bc51051b42740796fe8d87a6396 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 3 Jun 2019 14:01:00 -0400 Subject: [PATCH 23/27] add more tests --- .../spark_job/test/unit/spark_util_test.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 1e6d9ed526..a7ec5f3916 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -643,6 +643,56 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont "c_str": {"name": "c_str", "type": "INT_COLUMN", "required": False, "id": "3"}, }, }, + { + "data": [("a", 1, None), ("b", 1, None), ("c", 1, 4)], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", IntegerType()), + StructField("c_long", IntegerType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "2"}, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": False, + "id": "3", + }, + }, + }, + { + "data": [("a", 1, None), ("b", 1, None), ("c", 1, 4)], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", IntegerType()), + StructField("c_long", IntegerType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INT_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "STRING_COLUMN", "required": True, "id": "2"}, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": False, + "id": "3", + }, + }, + }, ] for test in tests: From e99bcd87f35a0fe3c3fbead86850fd615641382d Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 3 Jun 2019 14:17:40 -0400 Subject: [PATCH 24/27] 
add list test --- .../spark_job/test/unit/spark_util_test.py | 61 ++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index a7ec5f3916..7ff5c60f2b 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -567,6 +567,40 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ("c_str", StringType()), ], }, + { + "data": [ + ("1", [0.1, 12.0], datetime.now()), + ("1", [1.23, 1.0], datetime.now()), + ("1", [12.3, 1.1], datetime.now()), + ], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", ArrayType(DoubleType()), True), + StructField("c_str", TimestampType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, + "b_float": { + "name": "b_float", + "type": "FLOAT_LIST_COLUMN", + "required": True, + "id": "2", + }, + "c_str": {"name": "c_str", "type": "INFERRED_COLUMN", "required": False, "id": "3"}, + }, + "expected_types": [ + ("a_str", StringType()), + ("b_float", ArrayType(FloatType(), True)), + ("c_str", StringType()), + ], + }, ] for test in tests: @@ -668,7 +702,7 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont }, }, }, - { + { "data": [("a", 1, None), ("b", 1, None), ("c", 1, 4)], "schema": StructType( [ @@ -693,6 +727,31 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont }, }, }, + { + "data": [("a", [1], None), ("b", [1], None), ("c", [1], 4)], + "schema": StructType( + [ + StructField("a_str", StringType()), + StructField("b_float", ArrayType(IntegerType()), True), + StructField("c_long", IntegerType()), + ] + ), + "env": [ + {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, + {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, + {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + ], + "raw_columns": { + "a_str": {"name": "a_str", "type": "INT_COLUMN", "required": True, "id": "1"}, + "b_float": {"name": "b_float", "type": "FLOAT_LIST_COLUMN", "required": True, "id": "2"}, + "c_long": { + "name": "c_long", + "type": "INFERRED_COLUMN", + "required": False, + "id": "3", + }, + }, + }, ] for test in tests: From aaa5f9df7d4b7672dd69830d3d2b942dd25265d5 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 3 Jun 2019 14:21:39 -0400 Subject: [PATCH 25/27] format and lint --- pkg/operator/api/userconfig/errors.go | 2 +- .../spark_job/test/unit/spark_util_test.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 32609239b4..46d5abb2e8 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -408,7 +408,7 @@ func ErrorEnvSchemaMismatch(env1, env2 *Environment) error { return Error{ Kind: ErrEnvSchemaMismatch, - message: fmt.Sprintf("schemas diverged between environments, %s lists $s columns but %s does not", + message: fmt.Sprintf("schemas diverged between environments, %s lists %s columns but %s does not", env1.Name, s.StrsAnd(difference.Slice()), env2.Name, diff --git 
a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 7ff5c60f2b..a95b95085a 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -718,7 +718,12 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ], "raw_columns": { "a_str": {"name": "a_str", "type": "INT_COLUMN", "required": True, "id": "1"}, - "b_float": {"name": "b_float", "type": "STRING_COLUMN", "required": True, "id": "2"}, + "b_float": { + "name": "b_float", + "type": "STRING_COLUMN", + "required": True, + "id": "2", + }, "c_long": { "name": "c_long", "type": "INFERRED_COLUMN", @@ -743,7 +748,12 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ], "raw_columns": { "a_str": {"name": "a_str", "type": "INT_COLUMN", "required": True, "id": "1"}, - "b_float": {"name": "b_float", "type": "FLOAT_LIST_COLUMN", "required": True, "id": "2"}, + "b_float": { + "name": "b_float", + "type": "FLOAT_LIST_COLUMN", + "required": True, + "id": "2", + }, "c_long": { "name": "c_long", "type": "INFERRED_COLUMN", From 97e5207332a32f29233e886ba344137f9586887f Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 3 Jun 2019 18:23:37 -0400 Subject: [PATCH 26/27] address comments --- pkg/operator/api/userconfig/environments.go | 7 +------ pkg/operator/api/userconfig/errors.go | 10 +++------- pkg/workloads/spark_job/spark_util.py | 16 +++++++++++++--- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go index 9ae93b2cfa..58956033e9 100644 --- a/pkg/operator/api/userconfig/environments.go +++ b/pkg/operator/api/userconfig/environments.go @@ -340,12 +340,7 @@ func (environments Environments) Validate() error { ingestedColumns := environments[0].Data.GetIngestedColumns() for _, env := range environments[1:] { - difference := strset.Difference( - strset.New(ingestedColumns...), - strset.New(env.Data.GetIngestedColumns()...), - ) - - if len(difference) > 0 { + if !strset.New(ingestedColumns...).IsEqual(strset.New(env.Data.GetIngestedColumns()...)) { return ErrorEnvSchemaMismatch(environments[0], env) } } diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 46d5abb2e8..76617d9933 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -401,17 +401,13 @@ func ErrorSpecifyOnlyOneMissing(vals ...string) error { } func ErrorEnvSchemaMismatch(env1, env2 *Environment) error { - difference := strset.Difference( - strset.New(env1.Data.GetIngestedColumns()...), - strset.New(env2.Data.GetIngestedColumns()...), - ) - return Error{ Kind: ErrEnvSchemaMismatch, - message: fmt.Sprintf("schemas diverged between environments, %s lists %s columns but %s does not", + message: fmt.Sprintf("schemas diverged between environments, %s lists %s columns and %s lists %s", env1.Name, - s.StrsAnd(difference.Slice()), + s.StrsAnd(env1.Data.GetIngestedColumns()), env2.Name, + s.StrsAnd(env2.Data.GetIngestedColumns()), ), } } diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index e865d8880e..eb5f3b174e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -63,6 +63,11 @@ LongType(), FloatType(), DoubleType(), + ArrayType(FloatType(), True), + ArrayType(DoubleType(), True), + ArrayType(StringType(), True), + 
ArrayType(IntegerType(), True), + ArrayType(LongType(), True), ], consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], }, @@ -75,7 +80,13 @@ ArrayType(DoubleType(), True), ], consts.COLUMN_TYPE_STRING: [StringType()], - consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], + consts.COLUMN_TYPE_STRING_LIST: [ + ArrayType(StringType(), True), + ArrayType(FloatType(), True), + ArrayType(DoubleType(), True), + ArrayType(IntegerType(), True), + ArrayType(LongType(), True), + ], }, } @@ -267,7 +278,6 @@ def ingest(ctx, spark): df = df.withColumn( raw_column_name, F.col(raw_column_name).cast(expected_spark_type) ) - continue else: expected_spark_type = CORTEX_TYPE_TO_SPARK_TYPE[expected_cortex_type] if actual_spark_type in SPARK_TYPE_TO_CORTEX_TYPE: @@ -295,7 +305,7 @@ def ingest(ctx, spark): except Exception as e: raise UserException( "tried casting " + raw_column_name, - "to from ingested type " + actual_spark_type, + "from ingested type " + actual_spark_type, "to expected type " + expected_spark_type, "but got exception: " + e, ) From 54d5db6f9ece877922aa68ef8e59d266aaaaa24a Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Mon, 3 Jun 2019 15:41:57 -0700 Subject: [PATCH 27/27] Update errors.go --- pkg/operator/api/userconfig/errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 76617d9933..0f1cbc1d98 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -403,7 +403,7 @@ func ErrorSpecifyOnlyOneMissing(vals ...string) error { func ErrorEnvSchemaMismatch(env1, env2 *Environment) error { return Error{ Kind: ErrEnvSchemaMismatch, - message: fmt.Sprintf("schemas diverged between environments, %s lists %s columns and %s lists %s", + message: fmt.Sprintf("schemas diverge between environments (%s lists %s, and %s lists %s)", env1.Name, s.StrsAnd(env1.Data.GetIngestedColumns()), env2.Name,