Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

#include <cmath>
#include <string>

namespace tvm {
Expand All @@ -38,11 +39,21 @@ namespace tvm {
class BaseValueEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
if (std::isnan(lhs) && std::isnan(rhs)) {
// IEEE floats do not compare as equivalent to each other.
// However, for the purpose of comparing IR representation, two
// NaN values are equivalent.
return true;
} else if (std::isnan(lhs) || std::isnan(rhs)) {
return false;
} else if (lhs == rhs) {
return true;
} else {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
}

bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
Expand Down
13 changes: 12 additions & 1 deletion include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>

#include <cmath>
#include <functional>
#include <limits>
#include <string>

namespace tvm {
Expand All @@ -52,7 +54,16 @@ class BaseValueHash {

public:
uint64_t operator()(const float& key) const { return Reinterpret<float, uint32_t>(key); }
uint64_t operator()(const double& key) const { return Reinterpret<double, uint64_t>(key); }
uint64_t operator()(const double& key) const {
if (std::isnan(key)) {
// The IEEE format defines more than one bit-pattern that
// represents NaN. For the purpose of comparing IR
// representations, all NaN values are considered equivalent.
return Reinterpret<double, uint64_t>(std::numeric_limits<double>::quiet_NaN());
} else {
return Reinterpret<double, uint64_t>(key);
}
}
uint64_t operator()(const int64_t& key) const { return Reinterpret<int64_t, uint64_t>(key); }
uint64_t operator()(const uint64_t& key) const { return key; }
uint64_t operator()(const int& key) const { return Reinterpret<int, uint32_t>(key); }
Expand Down
43 changes: 43 additions & 0 deletions tests/python/tir-base/test_tir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,5 +419,48 @@ def func(A: T.Buffer(1, "int32")):
assert '<root>.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0]


def test_nan_values_are_equivalent():
"""Structural equality treats two NaN values as equivalent.

By IEEE, a check of `NaN == NaN` returns false, as does
`abs(NaN - NaN) < tolerance`. However, for the purpose of
comparing IR representations, both NaN values are equivalent.

"""

@T.prim_func(private=True)
def func_1():
return T.float32("nan")

@T.prim_func(private=True)
def func_2():
return T.float32("nan")

tvm.ir.assert_structural_equal(func_1, func_2)
assert tvm.ir.structural_hash(func_1) == tvm.ir.structural_hash(func_2)


def test_all_nan_values_are_equivalent():
"""Structural equality treats two NaN values as equivalent.

IEEE defines NaN as any value that has all exponent bits set,
and has a non-zero mantissa. For the purposes of comparing IR
representations, all NaN values are considered equivalent.

"""

# A NaN with the first payload bit set.
nan_all_zeros = np.int32(0x7FC00000).view("float32")

# A NaN with the last payload bit set.
nan_with_payload = np.int32(0x7F800001).view("float32")

float_1 = T.float32(nan_all_zeros)
float_2 = T.float32(nan_with_payload)

tvm.ir.assert_structural_equal(float_1, float_2)
assert tvm.ir.structural_hash(float_1) == tvm.ir.structural_hash(float_2)


if __name__ == "__main__":
tvm.testing.main()