diff --git a/torch_glow/src/GlowIValue.cpp b/torch_glow/src/GlowIValue.cpp index e49d4da35c..c8b8d44609 100644 --- a/torch_glow/src/GlowIValue.cpp +++ b/torch_glow/src/GlowIValue.cpp @@ -452,4 +452,72 @@ size_t GlowIValueMapHash::operator()(const GlowIValue &ival) const { return 0; } +/// Unwrap a Expected \p expectedIVal and call toDouble, +/// propogate any Errors. +Expected iValToDouble(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toDouble(); + } else { + return expectedIVal.takeError(); + } +} + +/// Unwrap a Expected \p expectedIVal and call toInt, +/// propogate any Errors. +Expected iValToInt(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toInt(); + } else { + return expectedIVal.takeError(); + } +} + +/// Unwrap a Expected \p expectedIVal and call toBool, +/// propogate any Errors. +Expected iValToBool(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toBool(); + } else { + return expectedIVal.takeError(); + } +} + +/// Unwrap a Expected \p expectedIVal and call toIntList, +/// propogate any Errors. +Expected *> +iValToIntList(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toIntList(); + } else { + return expectedIVal.takeError(); + } +} + +/// Unwrap a Expected \p expectedIVal and call toPTTensor, +/// propogate any Errors. +Expected iValToPTTensor(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toPTTensor(); + } else { + return expectedIVal.takeError(); + } +} + +Expected +iValToGenericMap(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toGenericMap(); + } else { + return expectedIVal.takeError(); + } +} + +Expected iValToString(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toString(); + } else { + return expectedIVal.takeError(); + } +} + } // namespace glow diff --git a/torch_glow/src/GlowIValue.h b/torch_glow/src/GlowIValue.h index b6333f1775..6cdb4a93ae 100644 --- a/torch_glow/src/GlowIValue.h +++ b/torch_glow/src/GlowIValue.h @@ -236,6 +236,35 @@ class GlowIValue { Error fromIValue(const at::IValue &ival); }; +/// Unwrap a Expected \p expectedIVal and call toDouble, +/// propogate any Errors. +Expected iValToDouble(Expected expectedIVal); + +/// Unwrap a Expected \p expectedIVal and call toInt, +/// propogate any Errors. +Expected iValToInt(Expected expectedIVal); + +/// Unwrap a Expected \p expectedIVal and call toBool, +/// propogate any Errors. +Expected iValToBool(Expected expectedIVal); + +/// Unwrap a Expected \p expectedIVal and call toIntList, +/// propogate any Errors. +Expected *> +iValToIntList(Expected expectedIVal); + +/// Unwrap a Expected \p expectedIVal and call toPTTensor, +/// propogate any Errors. +Expected iValToPTTensor(Expected expectedIVal); + +/// Unwrap a Expected \p expectedIVal and call toGenericMap, +/// propogate any Errors. +Expected iValToGenericMap(Expected expectedIVal); + +/// Unwrap a Expected \p expectedIVal and call toString, +/// propogate any Errors. +Expected iValToString(Expected expectedIVal); + } // namespace glow #endif // GLOW_TORCH_GLOW_SRC_GLOWIVALUE_H diff --git a/torch_glow/src/PyTorchModelLoader.cpp b/torch_glow/src/PyTorchModelLoader.cpp index 1b0af027dd..856cf78af6 100644 --- a/torch_glow/src/PyTorchModelLoader.cpp +++ b/torch_glow/src/PyTorchModelLoader.cpp @@ -176,57 +176,6 @@ contractIntIValIfNeeded(Expected expectedGlowIVal) { } } -/// Unwrap a Expected \p expectedIVal and call toDouble, -/// propogate any Errors. -Expected iValToDouble(Expected expectedIVal) { - if (expectedIVal) { - return (*expectedIVal)->toDouble(); - } else { - return expectedIVal.takeError(); - } -} - -/// Unwrap a Expected \p expectedIVal and call toInt, -/// propogate any Errors. -Expected iValToInt(Expected expectedIVal) { - if (expectedIVal) { - return (*expectedIVal)->toInt(); - } else { - return expectedIVal.takeError(); - } -} - -/// Unwrap a Expected \p expectedIVal and call toBool, -/// propogate any Errors. -Expected iValToBool(Expected expectedIVal) { - if (expectedIVal) { - return (*expectedIVal)->toBool(); - } else { - return expectedIVal.takeError(); - } -} - -/// Unwrap a Expected \p expectedIVal and call toIntList, -/// propogate any Errors. -Expected *> -iValToIntList(Expected expectedIVal) { - if (expectedIVal) { - return (*expectedIVal)->toIntList(); - } else { - return expectedIVal.takeError(); - } -} - -/// Unwrap a Expected \p expectedIVal and call toPTTensor, -/// propogate any Errors. -Expected iValToPTTensor(Expected expectedIVal) { - if (expectedIVal) { - return (*expectedIVal)->toPTTensor(); - } else { - return expectedIVal.takeError(); - } -} - /// Given a vector \p original containing elements of some type, \returns a /// vector of each element cast to another type T. template @@ -968,6 +917,7 @@ bool PyTorchModelLoader::isNodeSupported(const torch::jit::Node *ptNode) { } const auto &mapping = getSymbolsMapping(); + return mapping.count(kind) != 0; }