-
Notifications
You must be signed in to change notification settings - Fork 7
Expr simplifier: implement prove::isPositive, prove::isNonNegative, prove::isNonZero #2273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
||
bool NamedScalar::isTensorSize() const { | ||
static const std::regex r(R"(T\d+\.size\[\d+\])"); | ||
return std::regex_match(name(), r); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be fine, but why don't we add static methods to create tensor size and stride named scalars and explicitly tag them as such? That seems more robust than the string-based pattern match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I generally agree, but I prefer to keep NamedScalar
string only. We could have new subclasses of Val
for values that has special meaning, for example class TensorAttribute : public Val
. And I believe a good place to do this refactor is at #2282, where I need to access the .data
of a global tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that would be better, and it should have a dependency connection to the tensor itself. I just don't like the string matching-based approach as it seems backward.
} | ||
} else if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) { | ||
auto op = bop->getBinaryOpType(); | ||
if (op == BinaryOpType::Mod || op == BinaryOpType::Div || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are only these ops considered here? Why not, e.g., BinaryOpType::Add
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because Add
and Mul
are already flattened as FlattenedAssocCommOp
, so we will not see BinaryOp
with Add
and Mul
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the fattening done? Does this function assume the value
parameter is flattened?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flatten is done before running any other passes:
auto simplified = assoc_comm::flatten(value); |
So all other passes assumes the input
value
already flattened.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, thanks. Could you mention that as a comment? Probably obvious to you, but it would have been definitely a helpful comment for me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, thanks. Could you mention that as a comment? Probably obvious to you, but it would have been definitely a helpful comment for me.
Oh, sorry, missed this comment. Will add to #2275
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in dbe16cd
op == BinaryOpType::Eq && prove::isNonZero(bop->lhs(), var_info)) { | ||
return IrBuilder::newConstant(false, DataType::Bool); | ||
} | ||
} else if (bop->lhs()->isZeroInt()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Looks like the block below is mostly a duplicate of the above block. Would be great if refactored.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can not think of a more readable way to refactor this. This is like implementing a table lookup using if
statements. Different branches will look similar, but each with some small modifications. To me, the most straightforward way to implement this is to just list them all.
…into compatible-sign-check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for answering the questions.
This PR implements
prove::isPositive
,prove::isNonNegative
and improveprove::isNonZero
. This PR also added a new passeliminateTrivialPredicate
to simplify compare with zero. I don't theeliminateTrivialPredicate
will be super important to simplify predicates that we generate, but it is super helpful for me to write unit tests to check if these proves works as expected.The proves
prove::isPositive
,prove::isNonNegative
, andprove::isNonZero
are very important for expression simplification because many rules require sign compatibility check: