76 changes: 76 additions & 0 deletions test/test_compile.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import random
import unittest

@@ -11,6 +12,7 @@
import torch
import torch.nn as nn
from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
from float8_experimental.float8_tensor import Float8Tensor

# Setting to unblock for calling contiguous in backwards
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -76,5 +78,79 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp
    _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)


class TestGraphBreaks:

(1) test_compile is a weird file when we have test/dynamo/*
(2) Can we use regular dynamo testing infra? compile counter, etc?
(3) 1 test class per test file please
(4) This class name is wrong? It's testing fp8 + dynamo, not testing graph breaks (I would imagine TestGraphBreaks would test things like deep dynamo internals).

1.) We don't have test/dynamo/*
2.) Sure, is there a pointer for this?
3.) I hadn't heard of this rule before; is this in some style guide or convention somewhere?
4.) I thought that's implied by being in float8_experimental, no?

(1) I know, but you can put a test there that depends on this, and have CI install it, right? Alternatively, if you depend on dynamo, just pull the deps out

(2) Yes! Check any test under test/dynamo/*, you don't have to inherit from the base class, but it does cover things like .reset() for you nicely (which I think we missed here!) - grep for CompileCounter :)

(3) No, but I think it helps with organization. It's just a tiny nitpick.

(4) True, that is my mistake.
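
For reference, here is a minimal sketch of the CompileCounter pattern the reviewer points at in (2). It is illustrative only and not part of this diff; it assumes torch._dynamo.testing.CompileCounter and torch._dynamo.reset() from the dynamo test utilities, and uses a toy function fn:

import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

torch._dynamo.reset()   # clear previously compiled frames so the count starts at zero
cnt = CompileCounter()  # backend that records frame_count / op_count and otherwise runs the graph eagerly

def fn(x):
    return torch.relu(x) + 1

compiled_fn = torch.compile(fn, backend=cnt)
compiled_fn(torch.randn(4))

# If dynamo had silently skipped fn, frame_count would stay at 0 and this assert would fail.
assert cnt.frame_count == 1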

    class MockLinear(torch.nn.Module):
        def __init__(self, graph_break: bool):
            super().__init__()
            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
            self.graph_break = graph_break

        def forward(self, x):
            x_fp8 = Float8Tensor.to_float8(
                x,
                self.fp8_scale_x,
                torch.float8_e4m3fn,
                self.fp8_amax_x,
                emulate=True,  # TODO: I set this to True so that people on A100 can test, but once fix is in, set to False
            )
            if self.graph_break:
                torch._dynamo.graph_break()
                x_hp = x_fp8.to_original_precision()
                return x_hp
            return x_fp8

    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
    def test_float8_with_graph_break_in_the_middle(self):
        """Test that having a Float8Tensor object at the boundary of a subgraph works."""
        mod = self.MockLinear(graph_break=True).cuda()
        compiled_mod = copy.deepcopy(mod)
        compiled_mod = torch.compile(compiled_mod)
        x = torch.randn(16, 16, device="cuda")
        y_eager = mod(x)
        y_compiled = compiled_mod(x)
        torch.testing.assert_close(y_eager, y_compiled)

    def test_float8_graph_input(self):
        """Test that having a Float8Tensor object as a graph input works."""

        def to_float(x):
            return x.to_original_precision()

        to_float = torch.compile(to_float)

        mod = self.MockLinear(graph_break=False).cuda()
        x = torch.randn(2, 2, device="cuda")
        compiled_to_float = torch.compile(to_float)

to_float could hit a skip, and look exactly identical, and still return as compiled_to_float, and this test would pass.

You are right, that's a bug; line 120 should be removed.

Sorry, that's not what I mean! I mean that if you don't use our test infra that counts frames, you could hit a skip, get a frame_count of 0, and this test would pass!

Discussed offline - this is the only grounds for my reject.

        y = mod(x)
        y2_eager = to_float(y)
        y2_compiled = compiled_to_float(y)
        torch.testing.assert_close(y2_eager, y2_compiled)
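
Given the frame-count concern discussed above, one way this test could be restructured so that a silent dynamo skip fails loudly; this is a sketch only, assuming torch._dynamo.testing.CompileCounter is imported at module level, with a hypothetical method name, and it is not the change made in this PR:

    def test_float8_graph_input_counted(self):
        """Sketch: same scenario, but assert that exactly one frame was actually compiled."""
        # requires: from torch._dynamo.testing import CompileCounter (module-level import)
        torch._dynamo.reset()
        cnt = CompileCounter()

        def to_float(x):
            return x.to_original_precision()

        # Keep the eager reference uncompiled; compile a separate callable with the counting backend.
        compiled_to_float = torch.compile(to_float, backend=cnt)

        mod = self.MockLinear(graph_break=False).cuda()
        x = torch.randn(2, 2, device="cuda")
        y = mod(x)

        y2_eager = to_float(y)
        y2_compiled = compiled_to_float(y)
        torch.testing.assert_close(y2_eager, y2_compiled)
        # A skipped frame would leave frame_count at 0, so a vacuous pass is caught here.
        assert cnt.frame_count == 1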

    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
    def test_float8_graph_output(self):
        """Test that having a Float8Tensor object as a graph output works."""
        mod = self.MockLinear(graph_break=False).cuda()
        compiled_mod = torch.compile(mod)
        x = torch.randn(16, 16, device="cuda")
        y_compiled = compiled_mod(x)

        tensors, ctx = y_compiled.__tensor_flatten__()
        for tensor in tensors:
            assert not isinstance(
                getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
            ), "Float8Tensor should not contain any FakeTensors!"
        assert isinstance(
            y_compiled._orig_dtype, torch.dtype
        ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
            type(y_compiled._orig_dtype)
        )
        assert isinstance(
            y_compiled._emulate, bool
        ), "Float8Tensor._emulate should be a bool but got {}".format(
            type(y_compiled._emulate)
        )


if __name__ == "__main__":
    pytest.main([__file__])