Skip to content

Commit 0546c76

Browse files
jackm321facebook-github-bot
authored andcommitted
Move Expected IValue helper methods to common (#4212)
Summary: Pull Request resolved: #4212 Move methods used to deal with Expected IVavlues in the PyTorchModelLoader to Common so that they can be used in other code as well Reviewed By: yinghai Differential Revision: D20068995 fbshipit-source-id: 2926911bb918eaa21eda3ca381f2bea239e89454
1 parent daa534a commit 0546c76

File tree

3 files changed

+98
-51
lines changed

3 files changed

+98
-51
lines changed

torch_glow/src/GlowIValue.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,4 +452,72 @@ size_t GlowIValueMapHash::operator()(const GlowIValue &ival) const {
452452
return 0;
453453
}
454454

455+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toDouble,
456+
/// propogate any Errors.
457+
Expected<double> iValToDouble(Expected<GlowIValue *> expectedIVal) {
458+
if (expectedIVal) {
459+
return (*expectedIVal)->toDouble();
460+
} else {
461+
return expectedIVal.takeError();
462+
}
463+
}
464+
465+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toInt,
466+
/// propogate any Errors.
467+
Expected<int64_t> iValToInt(Expected<GlowIValue *> expectedIVal) {
468+
if (expectedIVal) {
469+
return (*expectedIVal)->toInt();
470+
} else {
471+
return expectedIVal.takeError();
472+
}
473+
}
474+
475+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toBool,
476+
/// propogate any Errors.
477+
Expected<bool> iValToBool(Expected<GlowIValue *> expectedIVal) {
478+
if (expectedIVal) {
479+
return (*expectedIVal)->toBool();
480+
} else {
481+
return expectedIVal.takeError();
482+
}
483+
}
484+
485+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toIntList,
486+
/// propogate any Errors.
487+
Expected<std::vector<int64_t> *>
488+
iValToIntList(Expected<GlowIValue *> expectedIVal) {
489+
if (expectedIVal) {
490+
return (*expectedIVal)->toIntList();
491+
} else {
492+
return expectedIVal.takeError();
493+
}
494+
}
495+
496+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toPTTensor,
497+
/// propogate any Errors.
498+
Expected<at::Tensor *> iValToPTTensor(Expected<GlowIValue *> expectedIVal) {
499+
if (expectedIVal) {
500+
return (*expectedIVal)->toPTTensor();
501+
} else {
502+
return expectedIVal.takeError();
503+
}
504+
}
505+
506+
Expected<GlowIValueMap *>
507+
iValToGenericMap(Expected<GlowIValue *> expectedIVal) {
508+
if (expectedIVal) {
509+
return (*expectedIVal)->toGenericMap();
510+
} else {
511+
return expectedIVal.takeError();
512+
}
513+
}
514+
515+
Expected<std::string *> iValToString(Expected<GlowIValue *> expectedIVal) {
516+
if (expectedIVal) {
517+
return (*expectedIVal)->toString();
518+
} else {
519+
return expectedIVal.takeError();
520+
}
521+
}
522+
455523
} // namespace glow

torch_glow/src/GlowIValue.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,35 @@ class GlowIValue {
236236
Error fromIValue(const at::IValue &ival);
237237
};
238238

239+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toDouble,
240+
/// propogate any Errors.
241+
Expected<double> iValToDouble(Expected<GlowIValue *> expectedIVal);
242+
243+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toInt,
244+
/// propogate any Errors.
245+
Expected<int64_t> iValToInt(Expected<GlowIValue *> expectedIVal);
246+
247+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toBool,
248+
/// propogate any Errors.
249+
Expected<bool> iValToBool(Expected<GlowIValue *> expectedIVal);
250+
251+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toIntList,
252+
/// propogate any Errors.
253+
Expected<std::vector<int64_t> *>
254+
iValToIntList(Expected<GlowIValue *> expectedIVal);
255+
256+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toPTTensor,
257+
/// propogate any Errors.
258+
Expected<at::Tensor *> iValToPTTensor(Expected<GlowIValue *> expectedIVal);
259+
260+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toGenericMap,
261+
/// propogate any Errors.
262+
Expected<GlowIValueMap *> iValToGenericMap(Expected<GlowIValue *> expectedIVal);
263+
264+
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toString,
265+
/// propogate any Errors.
266+
Expected<std::string *> iValToString(Expected<GlowIValue *> expectedIVal);
267+
239268
} // namespace glow
240269

241270
#endif // GLOW_TORCH_GLOW_SRC_GLOWIVALUE_H

torch_glow/src/PyTorchModelLoader.cpp

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -176,57 +176,6 @@ contractIntIValIfNeeded(Expected<GlowIValue *> expectedGlowIVal) {
176176
}
177177
}
178178

179-
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toDouble,
180-
/// propogate any Errors.
181-
Expected<double> iValToDouble(Expected<GlowIValue *> expectedIVal) {
182-
if (expectedIVal) {
183-
return (*expectedIVal)->toDouble();
184-
} else {
185-
return expectedIVal.takeError();
186-
}
187-
}
188-
189-
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toInt,
190-
/// propogate any Errors.
191-
Expected<int64_t> iValToInt(Expected<GlowIValue *> expectedIVal) {
192-
if (expectedIVal) {
193-
return (*expectedIVal)->toInt();
194-
} else {
195-
return expectedIVal.takeError();
196-
}
197-
}
198-
199-
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toBool,
200-
/// propogate any Errors.
201-
Expected<bool> iValToBool(Expected<GlowIValue *> expectedIVal) {
202-
if (expectedIVal) {
203-
return (*expectedIVal)->toBool();
204-
} else {
205-
return expectedIVal.takeError();
206-
}
207-
}
208-
209-
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toIntList,
210-
/// propogate any Errors.
211-
Expected<std::vector<int64_t> *>
212-
iValToIntList(Expected<GlowIValue *> expectedIVal) {
213-
if (expectedIVal) {
214-
return (*expectedIVal)->toIntList();
215-
} else {
216-
return expectedIVal.takeError();
217-
}
218-
}
219-
220-
/// Unwrap a Expected<GlowIValue *> \p expectedIVal and call toPTTensor,
221-
/// propogate any Errors.
222-
Expected<at::Tensor *> iValToPTTensor(Expected<GlowIValue *> expectedIVal) {
223-
if (expectedIVal) {
224-
return (*expectedIVal)->toPTTensor();
225-
} else {
226-
return expectedIVal.takeError();
227-
}
228-
}
229-
230179
/// Given a vector \p original containing elements of some type, \returns a
231180
/// vector of each element cast to another type T.
232181
template <typename T, typename OriginalT>
@@ -968,6 +917,7 @@ bool PyTorchModelLoader::isNodeSupported(const torch::jit::Node *ptNode) {
968917
}
969918

970919
const auto &mapping = getSymbolsMapping();
920+
971921
return mapping.count(kind) != 0;
972922
}
973923

0 commit comments

Comments
 (0)