Skip to content

Commit 1d370be

Browse files
Hector YuenPenghuiCheng
Hector Yuen
authored andcommitted
add a Float16UniformFill (pytorch#11123)
Summary: Pull Request resolved: pytorch#11123 this adds an operator that fills a tensor with a uniform(min, max) the implementation is to use the fp32 generator and convert to fp16 if performance becomes an issue we could resort to intrinsics Reviewed By: jspark1105, chocjy Differential Revision: D9598142 fbshipit-source-id: 5aeab99acf7c3596fa6c33611d9d2c484f7c1145
1 parent f778921 commit 1d370be

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

caffe2/operators/half_float_ops.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,37 @@ class Float16ConstantFillOp : public Operator<CPUContext> {
3939
vector<TIndex> shape_;
4040
};
4141

42+
class Float16UniformFillOp : public Operator<CPUContext> {
43+
public:
44+
Float16UniformFillOp(const OperatorDef& operator_def, Workspace* ws)
45+
: Operator<CPUContext>(operator_def, ws),
46+
shape_(this->template GetRepeatedArgument<int64_t>("shape")),
47+
min_(this->template GetSingleArgument<float>("min", 0)),
48+
max_(this->template GetSingleArgument<float>("max", 1)) {
49+
if (InputSize() == 3) {
50+
CAFFE_ENFORCE(
51+
!this->template HasSingleArgumentOfType<float>("min"),
52+
"Cannot set both min arg and min input blob");
53+
CAFFE_ENFORCE(
54+
!this->template HasSingleArgumentOfType<float>("max"),
55+
"Cannot set both max arg and max input blob");
56+
} else {
57+
CAFFE_ENFORCE_LT(
58+
min_, max_, "Max value should be bigger than min value.");
59+
}
60+
}
61+
62+
USE_OPERATOR_FUNCTIONS(CPUContext);
63+
virtual ~Float16UniformFillOp() {}
64+
65+
bool RunOnDevice() override;
66+
67+
private:
68+
vector<TIndex> shape_;
69+
float min_;
70+
float max_;
71+
};
72+
4273
inline std::vector<TensorShape> Float16FillerTensorInference(
4374
const OperatorDef& def,
4475
const vector<TensorShape>& in) {

0 commit comments

Comments
 (0)