56 changes: 40 additions & 16 deletions tests/pytorch/test_fusible_ops.py
@@ -1749,43 +1749,65 @@ def test_constant_scale(
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

-    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75))
     @pytest.mark.parametrize("is_training", (True, False))
-    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128)))
     @pytest.mark.parametrize("dtype", _dtypes)
     def test_dropout(
         self,
         *,
         prob: float,
         is_training: bool,
+        quantization: Optional[str],
         shape: Iterable[int],
         dtype: torch.dtype,
         device: torch.device = "cuda",
     ):

+        # Skip invalid configurations
+        quantized_input = quantization is not None
+        maybe_skip_quantization(quantization, dims=shape, device=device)
+
         # Random data
-        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        x_test = x_ref.clone().requires_grad_()
-        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        dy_test = dy_ref.clone()
+        # Note: Shift values to make sure inputs are non-zero
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            test_is_quantized=quantized_input,
+        )
+        with torch.no_grad():
+            x_test += 1
+            x_ref.copy_(x_test)
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )

         # Apply dropout
         op = te_ops.Dropout(prob)
         if is_training:
             op.train()
         else:
             op.eval()
-        y = op(x_test)
-        y.backward(dy_test)
+        y_test = op(x_test)
+        y_test.backward(dy_test)

         # Check values
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         if is_training:
-            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
-            torch.testing.assert_close(y, x_ref * mask)
-            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+            tols = dtype_tols(dtype)
+            mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y_test, x_ref * mask, **tols)
+            torch.testing.assert_close(dx_test, dy_ref * mask, **tols)
         else:
-            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
-            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+            torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0)

         # Hypothesis testing for number of zeros
         # Note: A Bernoulli random variable with probability p has
@@ -1797,9 +1819,11 @@ def test_dropout(
         # p-value is less than 1% and we assume that the dropout
         # distribution is incorrect.
         if is_training:
-            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
-            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
-            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
+            prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel())
+            assert (
+                abs(z_score) < 2.5758
+            ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"


 class TestFusedOps:
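Note on the checks above: the test reconstructs the dropout mask from the output itself. With inverted dropout, every surviving element is scaled by 1/(1 - prob) and every dropped element is exactly zero, so (y != 0)/(1 - prob) recovers the mask; this is also why the inputs are shifted away from zero, since a genuinely zero input would be indistinguishable from a dropped one. The zero-count check is then a two-sided z-test at 99% confidence. A minimal standalone sketch of both checks (the function name and plain-tensor interface are illustrative, not part of the PR):

import math

import torch


def check_dropout_output(x: torch.Tensor, y: torch.Tensor, prob: float) -> None:
    """Sketch of the reference checks for inverted dropout in training mode."""
    # Surviving elements are scaled by 1 / (1 - prob); dropped elements are zero.
    mask = (y != 0).to(y.dtype) / (1 - prob)
    torch.testing.assert_close(y, x * mask)

    # The zero count is a sum of independent Bernoulli(prob) draws, so the
    # observed zero fraction has mean prob and variance prob * (1 - prob) / N.
    prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
    z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
    # 2.5758 is the two-sided critical value of the standard normal at 99%.
    assert abs(z_score) < 2.5758, f"{prob_observed=} deviates from {prob=}"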
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -69,6 +69,7 @@ list(APPEND transformer_engine_SOURCES
      transpose/quantize_transpose_vector_blockwise.cu
      transpose/swap_first_dims.cu
      activation/gelu.cu
+     dropout/dropout.cu
      fused_attn/flash_attn.cu
      fused_attn/context_parallel.cu
      fused_attn/kv_cache.cu
33 changes: 33 additions & 0 deletions transformer_engine/common/__init__.py
@@ -294,6 +294,38 @@ def _load_nvrtc():
     return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


+@functools.lru_cache(maxsize=None)
+def _load_curand():
+    """Load cuRAND shared library."""
+    # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
+    libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
+    libs = list(filter(lambda x: not ("stub" in x), libs))
+    libs.sort(reverse=True, key=os.path.basename)
+    if libs:
+        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+
+    # Attempt to locate cuRAND in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("curand")
+    if found:
+        return handle
+
+    # Attempt to locate cuRAND via ldconfig
+    libs = subprocess.check_output(
+        f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True
+    )
+    libs = libs.decode("utf-8").split("\n")
+    sos = []
+    for lib in libs:
+        if "libcurand" in lib and "=>" in lib:
+            sos.append(lib.split(">")[1].strip())
+    if sos:
+        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
+
+    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
+    return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+
+
 @functools.lru_cache(maxsize=None)
 def _load_core_library():
     """Load shared library with Transformer Engine C extensions"""
@@ -303,6 +335,7 @@ def _load_core_library():
 if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
     _CUDNN_LIB_CTYPES = _load_cudnn()
     _NVRTC_LIB_CTYPES = _load_nvrtc()
+    _CURAND_LIB_CTYPES = _load_curand()
     _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
     _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
     _TE_LIB_CTYPES = _load_core_library()
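Note on the loader above: _load_curand tries four locations in order — a recursive glob under CUDA_HOME/CUDA_PATH (defaulting to /usr/local/cuda, skipping stub libraries and sorting basenames in reverse so that versioned names like libcurand.so.10 win out), the cuRAND package in Python dist-packages via _load_nvidia_cuda_library, the ldconfig cache, and finally a bare dlopen that relies on LD_LIBRARY_PATH. The ldconfig branch parses entries of the following shape (output illustrative; the resolved path varies by system):

# Typical `ldconfig -p` entry (illustrative):
#     libcurand.so.10 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcurand.so.10
line = "libcurand.so.10 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcurand.so.10"
# Splitting on ">" (the tail of "=>") and stripping isolates the resolved path.
path = line.split(">")[1].strip()
assert path == "/usr/lib/x86_64-linux-gnu/libcurand.so.10"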