Skip to content

Commit 4c7fa06

Browse files
authored
Add some more validation checks for torch.linalg.eigh and torch.compile (#1580)
* Add some more validation checks for torch.linalg.eigh and torch.compile * Update test * Also update smoke_test.py * Fix lint
1 parent c6cbe77 commit 4c7fa06

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

check_binary.sh

+6
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,12 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE
404404
echo "Test that linalg works"
405405
python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))"
406406

407+
echo "Test that linalg.eigh works"
408+
python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(torch.mm(x.t(), x)))"
409+
410+
echo "Checking that basic torch.compile works"
411+
python ${TEST_CODE_DIR}/torch_compile_smoke.py
412+
407413
popd
408414
fi # if libtorch
409415
fi # if cuda

test/smoke_test/smoke_test.py

+3
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ def smoke_test_linalg() -> None:
193193
A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
194194
torch.linalg.svd(A)
195195

196+
A = torch.rand(3, 3, device="cuda")
197+
L, Q = torch.linalg.eigh(torch.mm(A.t(), A))
198+
196199

197200
def smoke_test_compile() -> None:
198201
supported_dtypes = [torch.float16, torch.float32, torch.float64]
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
4+
def foo(x: torch.Tensor) -> torch.Tensor:
5+
return torch.sin(x) + torch.cos(x)
6+
7+
8+
if __name__ == "__main__":
9+
x = torch.rand(3, 3, device="cuda")
10+
x_eager = foo(x)
11+
x_pt2 = torch.compile(foo)(x)
12+
print(torch.allclose(x_eager, x_pt2))

0 commit comments

Comments
 (0)