Vectorized welford #2204
Lift the predicated count division outside of the innermost loop if that loop is exactly mapped with vectorized IDs and is not a reduction domain. This is targeted at tuning outer-reduction grid Welford.
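The idea can be illustrated with a minimal host-side sketch. This is not the code generated by this PR. The names `WelfordState`, `updateNaive`, `updateHoisted`, and `kVec` are hypothetical, and the sketch assumes a single count shared by all vector lanes, as in an outer reduction. The baseline recomputes the count and divides inside the lane loop. The hoisted version computes the new count and its reciprocal once, under the predicate, and multiplies inside the loop.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kVec = 4;  // hypothetical vector width

// Per-lane running mean and M2, with one count shared by all lanes.
struct WelfordState {
  std::array<float, kVec> avg{};
  std::array<float, kVec> m2{};
  int count = 0;
};

// Baseline: the count update and the division sit inside the lane loop.
void updateNaive(WelfordState& s, const std::array<float, kVec>& x, bool pred) {
  int new_count = s.count;
  for (std::size_t i = 0; i < kVec; ++i) {
    if (pred) {
      new_count = s.count + 1;
      const float delta = x[i] - s.avg[i];
      s.avg[i] += delta / static_cast<float>(new_count);
      s.m2[i] += delta * (x[i] - s.avg[i]);
    }
  }
  s.count = new_count;
}

// Hoisted: the count and its reciprocal are loop invariant, so they are
// computed once before the lane loop. The reciprocal stays 0 when the
// predicate is false, so no division by a zero count is ever executed.
void updateHoisted(WelfordState& s, const std::array<float, kVec>& x, bool pred) {
  const int new_count = pred ? s.count + 1 : s.count;
  float reciprocal = 0.0f;
  if (pred) {
    assert(new_count > 0);
    reciprocal = 1.0f / static_cast<float>(new_count);
  }
  for (std::size_t i = 0; i < kVec; ++i) {
    if (pred) {
      const float delta = x[i] - s.avg[i];
      s.avg[i] += delta * reciprocal;
      s.m2[i] += delta * (x[i] - s.avg[i]);
    }
  }
  s.count = new_count;
}
```

Multiplying by a reciprocal is not bit-identical to dividing in general, so the two versions may differ by rounding on arbitrary inputs.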
Represents the 32-bit floating-point scalar value. Not supported in PyTorch, so can't be used as inputs to fusions
Minor questions, but overall looks fine, marking as approved.
case DataType::Float:
  return IrBuilder::create<Double>(DataType::Float);
I thought you added float scalar support, why is there not supposed to be a float entry here? Is it because this can be user facing?
I did, but then I felt it's sometimes inconvenient to have both Double and Float, so I dropped Float. The dtype field of Double is used to specify its actual type (#2203). It's also the case with Int, which can represent int64_t by Int(DataType::Int) and int by Int(DataType::Int32).
//   welfordVectorized(..., new_count, reciprocal, p);
// }
void vectorizeWithInlinePredicate(WelfordOp* wop) {
  kir::IfThenElse* wop_ite = scope_exprs_.back()->as<kir::IfThenElse>();
Should you also double check the ite only has one expression in it?
I added the check in isVectorizableWelford.
if (!(wop_ite->thenBody().size() == 1 && wop_ite->thenBody().at(0) == wop &&
wop_ite->elseBody().empty())) {
return false;
}
// Predicated case when the innermost loop has no externally
// visible effect except for the welford outputs
//
// Before:
Is the before supposed to be the same in this case as the inline predicate?
Yes, it's checked in isVectorizableWelford:
// The predicate should be either Manual or Inline
if (wop_ite->predicate()->predicate_type() != PredicateType::Manual &&
wop_ite->predicate()->predicate_type() != PredicateType::Inline) {
return false;
}
  registerInsertBeforeInnerMostLoop(reciprocal_expr);
} else {
  // Initialize reciprocal as 0;
  registerInsertBeforeInnerMostLoop(IrBuilder::create<UnaryOp>(
Could this just be where(pred, 1/new_count_float, 0)?
Actually, it doesn't seem to work. where only accepts Val*, so we would need to have 1/new_count_float as a separate expression and use its output as an input to where, which means 1/new_count_float would be executed unconditionally. Since new_count_float can be zero when the predicate is false, this would result in unsafe code. I'm going to leave it as is.
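The difference between the two forms can be sketched in plain C++. This is only an analogy and not the fuser IR. `whereValue`, `reciprocalOf`, `reciprocalEager`, and `reciprocalGuarded` are hypothetical names, and the counter exists only to make the number of executed divisions observable. A select that takes already-computed values forces the division to run before the select, while the guarded form skips it when the predicate is false.

```cpp
#include <cassert>

// Hypothetical value-based select: both operands are already computed
// by the time it is called.
float whereValue(bool pred, float if_true, float if_false) {
  return pred ? if_true : if_false;
}

// Division wrapped so that each execution is counted.
float reciprocalOf(float v, int& num_divisions) {
  ++num_divisions;
  return 1.0f / v;
}

// Eager form: the division is a separate expression feeding the select,
// so it runs regardless of the predicate.
float reciprocalEager(bool pred, float count, int& num_divisions) {
  const float r = reciprocalOf(count, num_divisions);
  return whereValue(pred, r, 0.0f);
}

// Guarded form: initialize to 0 and divide only when the predicate holds.
float reciprocalGuarded(bool pred, float count, int& num_divisions) {
  float r = 0.0f;
  if (pred) {
    assert(count != 0.0f);
    r = reciprocalOf(count, num_divisions);
  }
  return r;
}
```

With a false predicate, the eager form still performs one division while the guarded form performs none, which is the property the existing code relies on when the count may be zero.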
// it is invariant within the loop. The loop index can be replaced
// with whatever value within the loop range since it is
// independent of the loop index.
kir::TensorIndex* hoistCount(kir::TensorIndex* out_N) {
Wonder if it would be convenient to add a utility to detect which loops a Val is dependent on so we could assert the invariance this assumes (I know it's really just detected earlier). @zasdfgbnm
Yes, that would be useful. We could probably just do a search of a specific Val in an index math expression.
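Such a search could look roughly like the following sketch. It uses a hypothetical, self-contained expression tree (`Node`, `leaf`, `op`, `dependsOn`) rather than the fuser's actual Val and Expr classes, and it assumes that dependence on a loop is the same as the loop index node being reachable from the index expression.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical expression node: a leaf value or an operation on inputs.
struct Node {
  std::string name;                            // e.g. "i0", "stride", "*"
  std::vector<std::shared_ptr<Node>> inputs;   // empty for leaves
};
using NodePtr = std::shared_ptr<Node>;

NodePtr leaf(std::string name) {
  return std::make_shared<Node>(Node{std::move(name), {}});
}

NodePtr op(std::string name, NodePtr lhs, NodePtr rhs) {
  return std::make_shared<Node>(
      Node{std::move(name), {std::move(lhs), std::move(rhs)}});
}

// Returns true if `target` (compared by identity) appears anywhere in the
// expression rooted at `expr`.
bool dependsOn(const NodePtr& expr, const NodePtr& target) {
  assert(target != nullptr);
  if (!expr) {
    return false;
  }
  if (expr == target) {
    return true;
  }
  for (const auto& input : expr->inputs) {
    if (dependsOn(input, target)) {
      return true;
    }
  }
  return false;
}
```

An invariance check before hoisting would then amount to asserting that `dependsOn(index_expr, loop_index)` is false for the loop being hoisted out of.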
// Hoist the count TensorIndex out of the innermost loop, assuming
// it is invariant within the loop. The loop index can be replaced
// with whatever value within the loop range since it is
// independent of the loop index.
Why does the loop index have to be replaced? Is it not already just 0?
No. At this point the loop index is still used to index the count Val, because indexing doesn't know that the value of the loop index doesn't matter.
//
// The pattern matching here only works with a simple sequence of
// expressions but should be sufficient as it is likely the most
// common situation.
Seems like this could be a nice generic pass like (a future) index hoisting. Seems to me it's a cleanup of unrolling loops, i.e. trying to simplify the else of the unrolled ite. Noting as future work.
Agreed. This particular implementation is really simplistic, as I don't expect any complex loop body to appear with a serial WelfordOp. A generic pass would need to consider a wider variety of patterns.
// guaranteed that its predicate can be safely replaced with the
// predicate of the WelfordOp. If not, that doesn't necessarily
// mean the expr predicate is different, but likely not
// worthwhile to consider.
If two expressions are loop mapped, it doesn't necessarily mean they have the same predicate; it just means their predicates are safe to join. This is because of broadcast resolution: the full predicate is a composite of all the root domains covered in the roots of the expressions. I don't see any specific errors in the code. I'm just stating that this is how we do the unrolled predicate: we accumulate all the predicates from the various expressions to make sure we cover all the root IterDomains.
@zasdfgbnm Merged the PR, but let me know if you have any comments.
Apply loop-invariant code hoisting to serial WelfordOps. For example, when the innermost loop looks like:
The count input should be invariant when the loop is not a reduction loop, and if the predicate is also loop invariant, then this can be transformed as:
After:
This is meant to optimize outer welford reductions. With this, FusionComputeWith6 looks like: