Skip to content

Commit e619e54

Browse files
authored
Merge pull request #59 from csarofeen/misc
Rename Fusion::uses() to unordered_uses()
2 parents bb9ca9d + 607862d commit e619e54

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void Expr::constDispatch(T handler, const Expr* expr) {
193193
ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
194194
return;
195195
case ExprType::BroadcastOp:
196-
ptr(handler)->handle(static_cast<const BroadcastOp* const>(expr));
196+
ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
197197
return;
198198
case ExprType::ForLoop:
199199
ptr(handler)->handle(static_cast<const ForLoop*>(expr));

torch/csrc/jit/codegen/cuda/fusion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void Fusion::removeVal(Val* val) {
112112
if (orig != nullptr)
113113
removeExpr(origin(val));
114114

115-
for (Expr* use : uses(val))
115+
for (Expr* use : unordered_uses(val))
116116
removeExpr(use);
117117

118118
val_set_.erase(val);
@@ -130,7 +130,7 @@ void Fusion::addInput(Val* const input) {
130130
assertInFusion(input, "Cannot register input ");
131131

132132
if (input->getValType().value() == ValType::TensorView) {
133-
TensorView* tv = static_cast<TensorView* const>(input);
133+
auto tv = input->as<TensorView>();
134134
if (tv->hasReduction())
135135
TORCH_WARN_ONCE(
136136
"Registered input ",
@@ -144,7 +144,7 @@ void Fusion::addInput(Val* const input) {
144144
void Fusion::addOutput(Val* const output) {
145145
assertInFusion(output, "Cannot register output ");
146146
if (output->getValType().value() == ValType::TensorView) {
147-
TensorView* tv = static_cast<TensorView* const>(output);
147+
auto tv = output->as<TensorView>();
148148
if (TensorDomain::hasBroadcast(tv->getRootDomain()))
149149
// Go to the root as we can merge bcast and
150150
// non-bcast dims, making a non-bcast dim.
@@ -311,7 +311,7 @@ const std::unordered_set<Expr*>& Fusion::unordered_exprs() const noexcept {
311311
return expr_set_;
312312
}
313313

314-
std::unordered_set<Expr*> Fusion::uses(Val* val) const {
314+
std::unordered_set<Expr*> Fusion::unordered_uses(Val* val) const {
315315
assertInFusion(val, "Cannot detect where val was used, ");
316316
if (uses_.find(val) != uses_.end()) {
317317
auto ret = uses_.find(val)->second;

torch/csrc/jit/codegen/cuda/fusion.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {
192192
const std::unordered_set<Expr*>& unordered_exprs() const noexcept;
193193

194194
// Return all Exprs that use val
195-
std::unordered_set<Expr*> uses(Val* val) const;
195+
std::unordered_set<Expr*> unordered_uses(Val* val) const;
196196

197197
// Return the Expr that produces val
198198
Expr* origin(Val* val) const;

torch/csrc/jit/codegen/cuda/iter_visitor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ std::vector<Statement*> BackwardVisitor::next(Val* val) {
250250
// Going to sort based on relative topological position
251251
std::map<size_t, Statement*> exprs;
252252

253-
for (auto expr : FusionGuard::getCurFusion()->uses(val))
253+
for (auto expr : FusionGuard::getCurFusion()->unordered_uses(val))
254254
// Make sure it's an expr we can traverse
255255
if (traversal_exprs_.find(expr) != traversal_exprs_.end())
256256
exprs[traversal_exprs_[expr]] = expr;

0 commit comments

Comments (0)