From b61cdc3b5174bbf40885f4acd1bde14981d11544 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 4 Mar 2024 16:49:12 -0800
Subject: [PATCH] Fix static runtime sigrid_hash precomputed multiplier pass

Reviewed By: pls331, houseroad

Differential Revision: D54336561
---
 torch_glow/src/ShapeInferenceEngine.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torch_glow/src/ShapeInferenceEngine.cpp b/torch_glow/src/ShapeInferenceEngine.cpp
index 24dd73f0aa..81b13e78d5 100644
--- a/torch_glow/src/ShapeInferenceEngine.cpp
+++ b/torch_glow/src/ShapeInferenceEngine.cpp
@@ -3437,7 +3437,8 @@ ShapeInferenceEngine::argmin(const MetaStack &variableMetas) {
  * int salt,
  * int maxValue,
  * Tensor multiplier_shift,
- * bool hashIntoInt32
+ * bool hashIntoInt32,
+ * bool? noHashNegSalt
  * ) -> Tensor
  *
  *
@@ -3445,8 +3446,8 @@ ShapeInferenceEngine::argmin(const MetaStack &variableMetas) {
 Expected
 ShapeInferenceEngine::sigridHashPrecompute(const MetaStack &variableMetas) {
   RETURN_ERR_IF_NOT(
-      variableMetas.size() == 5,
-      strFormat("Expected 5 inputs, got %zu", variableMetas.size()));
+      variableMetas.size() == 6,
+      strFormat("Expected 6 inputs, got %zu", variableMetas.size()));
 
   TensorShape shape = variableMetas[0].shape();