-
Notifications
You must be signed in to change notification settings - Fork 19
Add tests for Float8Tensor at graph boundaries #196
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import copy | ||
import random | ||
import unittest | ||
|
||
|
@@ -11,6 +12,7 @@ | |
import torch | ||
import torch.nn as nn | ||
from float8_experimental.float8_linear_utils import get_float8_linear, LinearType | ||
from float8_experimental.float8_tensor import Float8Tensor | ||
|
||
# Setting to unblock for calling contiguous in backwards | ||
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) | ||
|
@@ -76,5 +78,79 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp | |
_test_compile_base("inductor", fullgraph, emulate, linear_type, dtype) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGraphBreaks:
    """Tests for Float8Tensor objects crossing torch.compile graph boundaries.

    Covers a Float8Tensor that is live across a graph break, one that is
    passed in as a graph input, and one that is returned as a graph output.
    All tests place their module and inputs on CUDA.
    """

    class MockLinear(torch.nn.Module):
        """Minimal module that casts its input to a Float8Tensor.

        Args:
            graph_break: when True, ``forward`` forces a dynamo graph break
                after the cast and returns the tensor converted back to its
                original precision; when False, ``forward`` returns the
                Float8Tensor itself.
        """

        def __init__(self, graph_break: bool):
            super().__init__()
            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
            self.graph_break = graph_break

        def forward(self, x):
            """Cast ``x`` to float8 and optionally round-trip it across a graph break."""
            x_fp8 = Float8Tensor.to_float8(
                x,
                self.fp8_scale_x,
                torch.float8_e4m3fn,
                self.fp8_amax_x,
                emulate=True,  # TODO: I set this to True so that people on A100 can test, but once fix is in, set to False
            )
            if self.graph_break:
                # Split the compiled graph here so the Float8Tensor sits at the
                # boundary between two subgraphs.
                torch._dynamo.graph_break()
                x_hp = x_fp8.to_original_precision()
                return x_hp
            return x_fp8

    def setup_method(self):
        """Reset dynamo state so compilation caches do not leak between tests."""
        torch._dynamo.reset()

    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
    def test_float8_with_graph_break_in_the_middle(self):
        """Test that a Float8Tensor at the boundary of a subgraph works."""
        mod = self.MockLinear(graph_break=True).cuda()
        compiled_mod = copy.deepcopy(mod)
        compiled_mod = torch.compile(compiled_mod)
        x = torch.randn(16, 16, device="cuda")
        y_eager = mod(x)
        y_compiled = compiled_mod(x)
        torch.testing.assert_close(y_eager, y_compiled)

    def test_float8_graph_input(self):
        """Test that a Float8Tensor can be passed as a graph input."""

        def to_float(x):
            return x.to_original_precision()

        mod = self.MockLinear(graph_break=False).cuda()
        x = torch.randn(2, 2, device="cuda")
        # `to_float` itself must stay uncompiled so that it serves as a true
        # eager reference for the compiled variant below.
        compiled_to_float = torch.compile(to_float)

        y = mod(x)
        y2_eager = to_float(y)
        y2_compiled = compiled_to_float(y)
        torch.testing.assert_close(y2_eager, y2_compiled)

    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
    def test_float8_graph_output(self):
        """Test that having Float8Tensor object as a graph output works"""
        mod = self.MockLinear(graph_break=False).cuda()
        compiled_mod = torch.compile(mod)
        x = torch.randn(16, 16, device="cuda")
        y_compiled = compiled_mod(x)

        # __tensor_flatten__ yields the names of the inner tensor attributes;
        # none of them may still be a tracing-time FakeTensor.
        tensor_attr_names, _ctx = y_compiled.__tensor_flatten__()
        for attr_name in tensor_attr_names:
            assert not isinstance(
                getattr(y_compiled, attr_name),
                torch._subclasses.fake_tensor.FakeTensor,
            ), "Float8Tensor should not contain any FakeTensors!"
        assert isinstance(
            y_compiled._orig_dtype, torch.dtype
        ), f"Float8Tensor._orig_dtype should be a dtype but got {type(y_compiled._orig_dtype)}"
        assert isinstance(
            y_compiled._emulate, bool
        ), f"Float8Tensor._emulate should be a bool but got {type(y_compiled._emulate)}"
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(1) test_compile is a weird file when we have test/dynamo/*
(2) Can we use regular dynamo testing infra? compile counter, etc?
(3) 1 test class per test file please
(4) This class name is wrong? It's testing fp8 + dynamo, not testing graph breaks (I would imagine TestGraphBreaks would test things like deep dynamo internal working)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1.) We don't have test/dynamo/*
2.) Sure is there a pointer for this?
3.) I hadn't heard of this rule before. Is this in a style guide or convention somewhere?
4.) I thought that's implied by being in Float8Experimental, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(1) I know, but you can put a test there that depends on this, and have CI install it, right? Alternatively, if you depend on dynamo, just pull the deps out
(2) Yes! Check any test under test/dynamo/*. You don't have to inherit from the base class, but it does cover things like .reset() for you nicely (which I think we missed here!) - grep for CompileCounter :)
(3) No, but I think it helps with organization. It's just a tiny nitpick.
(4) True, that is my mistake.