diff --git a/test/test_compile.py b/test/test_compile.py
index 9b88811a..d39b7400 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 import random
 import unittest
 
@@ -11,6 +12,10 @@ import torch
 import torch.nn as nn
 
 from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_tensor import Float8Tensor
+
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend
 
 # Setting to unblock for calling contiguous in backwards
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -76,5 +81,87 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp
     _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)
 
 
+class TestGraphBreaks(DynamoTestCase):
+    class MockLinear(torch.nn.Module):
+        def __init__(self, graph_break: bool):
+            super().__init__()
+            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
+            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
+            self.graph_break = graph_break
+
+        def forward(self, x):
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.fp8_scale_x,
+                torch.float8_e4m3fn,
+                self.fp8_amax_x,
+                emulate=True,  # TODO: set to True so people on A100 can test; once the fix is in, set to False
+            )
+            if self.graph_break:
+                torch._dynamo.graph_break()
+                x_hp = x_fp8.to_original_precision()
+                return x_hp
+            return x_fp8
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_with_graph_break_in_the_middle(self):
+        """Test that having a Float8Tensor object at the boundary of a subgraph works"""
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=True).cuda()
+        compiled_mod = copy.deepcopy(mod)
+        compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        x = torch.randn(16, 16, device="cuda")
+        y_eager = mod(x)
+        y_compiled = compiled_mod(x)
+        self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
+        torch.testing.assert_close(y_eager, y_compiled)
+
+    def test_float8_graph_input(self):
+        """Test that having a Float8Tensor object as a graph input works"""
+
+        def to_float(x):
+            return x.to_original_precision()
+
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=False).cuda()
+        x = torch.randn(2, 2, device="cuda")
+        compiled_to_float = torch.compile(to_float, backend=cnts)
+        y = mod(x)
+        y2_eager = to_float(y)
+        y2_compiled = compiled_to_float(y)
+        self.assertEqual(
+            cnts.frame_count,
+            1,
+            "to_float was not compiled into 1 frame and likely encountered a skip!",
+        )
+        torch.testing.assert_close(y2_eager, y2_compiled)
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_graph_output(self):
+        """Test that having a Float8Tensor object as a graph output works"""
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=False).cuda()
+        compiled_mod = torch.compile(mod, backend=cnts)
+        x = torch.randn(16, 16, device="cuda")
+        y_compiled = compiled_mod(x)
+
+        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
+        tensors, ctx = y_compiled.__tensor_flatten__()
+        for tensor in tensors:
+            assert not isinstance(
+                getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
+            ), "Float8Tensor should not contain any FakeTensors!"
+        assert isinstance(
+            y_compiled._orig_dtype, torch.dtype
+        ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+            type(y_compiled._orig_dtype)
+        )
+        assert isinstance(
+            y_compiled._emulate, bool
+        ), "Float8Tensor._emulate should be a bool but got {}".format(
+            type(y_compiled._emulate)
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])