-
Notifications
You must be signed in to change notification settings - Fork 7
Reduce the work to add a new expr: remove ExprType #2186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -91,132 +91,167 @@ void Val::dispatch(T handler, Val* val) { | |||
|
|||
template <typename T> | |||
void Expr::dispatch(T handler, Expr* expr) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dispatch logic also needs refactor, but I will not do it in this PR
|
Failing tests are due to my stupid mistake, fixed now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have any particular concern about removing ExprType
. Just two comments:
getExprType().value() == ExprType::ABC
isn't equivalent toisA<ABC>()
as the latter allows any subclasses ofABC
.isStrictlyA<ABC>()
should be used most of the cases to keep the semantics of ExprType semantics. I don't see any problematic conversion in this PR, but just want to make sure you are also aware of the difference.- I don't know what the original motivation of having the
ExprType
enum was. Is there anything we are losing? Pinging @csarofeen. ForVal
, theValType
enum and the class type doesn't always match 1-to-1, e.g.,Bool
,Double
,Int
all haveValType::Scalar
.Val::isScalar()
would require a little longer code, but doesn't seem like a problem.
} | ||
|
||
private: | ||
template <int> // unused template argument |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh! Is this because functions don't allow template partial specialization?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, exactly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. Is there a name for this technique?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know. I just invented it ad hoc.
Yes, I am aware of that. Initially, I thought there was no 2-level inheritance in nvfuser, but after some looking, seems that the reductions are the only ops that use 2-level inheritance. I just revisited all the use of if (expr->isStrictlyA<kir::GridReduction>()) {
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GridReduction>());
} else if (expr->isStrictlyA<kir::GridWelford>()) {
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GridWelford>());
} else if (expr->isStrictlyA<kir::GroupedGridReduction>()) {
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GroupedGridReduction>());
} else if (expr->isStrictlyA<kir::GroupedGridWelford>()) {
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GroupedGridWelford>());
} else {
TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems fine to me, a little nervous about moving switch statements to if-elif-elif-... patterns but otherwise no objections.
expr->getExprType().value() == ExprType::GridBroadcast || | ||
expr->getExprType().value() == ExprType::GridWelford || | ||
expr->getExprType().value() == ExprType::GroupedGridWelford)) { | ||
(expr->isOneOf< |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function 😆
@@ -765,7 +765,7 @@ struct DummyExpr : public Expr { | |||
Val* _outrhs, | |||
Val* _lhs, | |||
Val* _rhs) | |||
: Expr(passkey, ExprType::UnaryOp) // Not terribly safe... | |||
: Expr(passkey) // terribly safe :-D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember why it wasn't terribly safe, but I'm glad it is terribly safe now 😆
In the past, I added many new expression types, and whenever I added one, I felt very painful because there were just so many places to modify. When working on the expression simplifier, I am trying to create a new temporary op
FlattenedAdd
which stores the terms ina + b + c + d + ...
in a single expression instead of making a tree of binary op add. This new op is mainly for the convenience of usage and will be removed after expression simplification, so I want to keep its definition in a single CPP file and not expose it outside. But our current design just doesn't allow me to do so.So I started a refactoring of expressions, trying to reduce the amount of mechanical work required for adding a new expression. Ideally, defining a new expression type for private usage should be doable in a single CPP file without modifying other files.
This PR is the first step of this refactor. It removes
ExprType
. C++'s RTTI already provides what we need, so I don't think we should keep this extra type.