@vthumbe1503 (Collaborator) commented Sep 22, 2025

Description

JAX integration for clamped SwiGLU.
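The description does not spell out the activation itself. As a rough sketch, assuming the GPT-OSS-style formulation of clamped SwiGLU (gate branch clamped from above at `limit`, linear branch clamped to `[-limit, limit]`, sigmoid sharpened by `alpha`), a scalar reference in plain Python might look like the following. The function name, the defaults `limit=7.0` and `alpha=1.702`, and the `+ 1` offset on the linear branch are assumptions taken from that formulation, not from this PR's code.

```python
import math


def _sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def clamped_swiglu(x_glu, x_linear, limit=7.0, alpha=1.702):
    """Scalar reference for a GPT-OSS-style clamped SwiGLU (assumed formulation)."""
    # Gate branch: clamped from above only.
    g = min(x_glu, limit)
    # Linear branch: clamped symmetrically to [-limit, limit].
    lin = max(-limit, min(x_linear, limit))
    # Sigmoid-weighted gate multiplied by the offset linear branch.
    return g * _sigmoid(alpha * g) * (lin + 1.0)


if __name__ == "__main__":
    print(clamped_swiglu(1.0, 2.0))
    # Inputs beyond the limit give the same result as inputs at the limit.
    print(clamped_swiglu(100.0, -100.0) == clamped_swiglu(7.0, -7.0))
```

A JAX version would presumably apply the same arithmetic elementwise to the two halves of the projected activation; the scalar form is used here only to keep the sketch dependency-free.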

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 30 commits September 19, 2025 06:10
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

revert accidental change

Signed-off-by: Varun Thumbe <[email protected]>

Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

fix merge conflict

Signed-off-by: Varun Thumbe <[email protected]>

bug: missed a } in the code

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

Add cuBLASMp-backed GEMM-like API to TE common (NVIDIA#1824)

* Pick up cuBLASMp during build

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Change lib order to fix link error

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Context creation, incomplete...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Test fixture

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* A sanity AgGemm test, failing...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix axes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Take care of uneven distribution

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use MPI to get position of local matrices

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Refactor

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Refactor & fixes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Gemm-RS

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Gemm-AR, not working...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fixes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Setting all-reduce epilogue for gemm-ar

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use supported shapes for GEMM-AR

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak tolerance

Signed-off-by: Vladimir Cherepanov <[email protected]>

* First shot at fp8

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use TensorHolder in tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More test configs

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Support comm_sm_count

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Parametrize dtypes for A, B and D separately

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak scaling

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Amax ptr

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Flags parity with cublas_gemm, saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Cleanup

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Bias tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix bias test

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Aux, saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* aux_ld

Signed-off-by: Vladimir Cherepanov <[email protected]>

* A fix

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use test::Tensor

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Set scale inv

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove unsupported test configs

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Replace libcal with NCCL

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add NVTX markers to API functions

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak GemmAr tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More test config

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix merge fallout

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove MPI dependency, comment API, add algo parameter

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix nvshmem dependency

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix nvshmem build

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Exclude CommGemm tests from L0_cppunittest

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add cpp_distributed sh file for CI

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Adapt tp TensorAllocator

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Oversubscribe is needed on some clusters

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix incomplete libcal removal

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Move CI tests to L1

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Rename context to include NVTE prefix

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove leftover code

Signed-off-by: Vladimir Cherepanov <[email protected]>

* NVTE_WITH_CUBLASMP off by default

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More detailed NVTE_CHECK diag

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Include stdbool header for legacy C compilers

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove now unused argument

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Abstract away cuBLASMp algo behind our own enum

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More detailed shape diag messages

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add license

Signed-off-by: Vladimir Cherepanov <[email protected]>

---------

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (NVIDIA#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.

Signed-off-by: Ming Huang <[email protected]>
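The workflow listed in this commit (AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM) can be illustrated with a small single-process simulation. This is only a sketch under stated assumptions: the helper names are hypothetical, the collectives are replaced by plain Python functions, and the FP8 cast is approximated by scaling and clipping to the E4M3 range without rounding to the 8-bit grid. It is not the Transformer Engine implementation.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def all_reduce_max(local_amaxes):
    """Stand-in for an all-reduce(max) collective across ranks."""
    return max(local_amaxes)


def quantize_with_amax(shard, amax):
    """Scale a shard into the FP8 range using a caller-supplied amax.

    A real FP8 cast also rounds to the 8-bit grid; that step is omitted here.
    """
    scale = FP8_E4M3_MAX / amax if amax > 0.0 else 1.0
    quantized = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x * scale)) for x in shard]
    return quantized, 1.0 / scale


def all_gather(shards):
    """Stand-in for an all-gather: concatenate every rank's shard."""
    return [x for shard in shards for x in shard]


# Two hypothetical ranks, each holding one shard of the activation.
rank_shards = [[0.5, -2.0, 1.0], [4.0, -8.0, 0.25]]

# 1. AR-max: every rank ends up with the same global amax.
global_amax = all_reduce_max([max(abs(x) for x in s) for s in rank_shards])

# 2. FP8 quant: each rank quantizes its own shard with the shared amax.
quantized = [quantize_with_amax(s, global_amax) for s in rank_shards]
scale_inv = quantized[0][1]

# 3. FP8 AG: all ranks used one scale, so the gathered data shares one scale_inv.
gathered = all_gather([q for q, _ in quantized])

# 4. The gathered data would feed the FP8 grouped GEMM; here it is only dequantized.
dequantized = [x * scale_inv for x in gathered]
```

The point of reducing the amax first is that every rank then quantizes with an identical scale, which is what makes gathering the already-quantized shards meaningful.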

* Slightly refactor

Signed-off-by: Ming Huang <[email protected]>

* Adding documentation for new args.

Signed-off-by: Ming Huang <[email protected]>

* Adding unit-tests.

Signed-off-by: Ming Huang <[email protected]>

* Adding license.

Signed-off-by: Ming Huang <[email protected]>

* Move unit-tests to L1.

Signed-off-by: Ming Huang <[email protected]>

* Move quantizer store/reset into FP8 only.

Signed-off-by: Ming Huang <[email protected]>

* Adding all layout support for Blackwell+

Signed-off-by: Ming Huang <[email protected]>

* Adopt the feedback from code-review.

Signed-off-by: Ming Huang <[email protected]>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Delay MeshResource validation until first usage (NVIDIA#2124)

Delay MeshResource validation until first usage

Signed-off-by: Jeremy Berchtold <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Decouple Recipe and ScalingMode (NVIDIA#1728)

* Decouple recipe and scaling mode

Signed-off-by: Jeremy Berchtold <[email protected]>

* Expose global QuantizeConfig instance as a getter

Signed-off-by: Jeremy Berchtold <[email protected]>

* Format and lint

Signed-off-by: Jeremy Berchtold <[email protected]>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling

Signed-off-by: Jeremy Berchtold <[email protected]>

* Rename UsageType to TensorSource

Signed-off-by: Jeremy Berchtold <[email protected]>

* Update test_layer.py

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: jberchtold-nvidia <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (NVIDIA#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (NVIDIA#2118)

* add amax input to DBiasQuantizePrimitive and FFI

Signed-off-by: Phuong Nguyen <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero

Signed-off-by: Phuong Nguyen <[email protected]>

* fix sharding rule

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (NVIDIA#2121)

Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Temporarily remove comm_gemm tests (NVIDIA#2133)

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] Disable determinism for sm100 (NVIDIA#2130)

* disable determinism for sm100+ and cudnn<9.14

Signed-off-by: Charlene Yang <[email protected]>

* fix remaining CI failures

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes

Signed-off-by: Charlene Yang <[email protected]>

* revert more changes

Signed-off-by: Charlene Yang <[email protected]>

* remove sm100 from determinism table

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] ONNX export of FP8 Current Scaling (NVIDIA#2068)

* Compute amax in normalization forward in current scaling in untuned kernels

Signed-off-by: Jan Bielak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* code drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* apply Tim's suggestions

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: Jan Bielak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (NVIDIA#2134)

use torch empty for empty shape instead of from_blob

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

build: pull cached wheels (NVIDIA#2127)

* build: pull cached wheels

Signed-off-by: oliver könig <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update setup.py

Signed-off-by: oliver könig <[email protected]>

---------

Signed-off-by: oliver könig <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

feat: Add support for multiple quantization modes in the UB communicators (NVIDIA#2043)

Signed-off-by: Varun Thumbe <[email protected]>

[Common] Add checks to CUDA kernel launch and CUDA API calls (NVIDIA#2074)

* add checks to cuda kernel launch and cuda API calls

Signed-off-by: Xin Yao <[email protected]>

* Remove exceptions from destructors

Signed-off-by: Tim Moon <[email protected]>

* fix weird dispatch in ln/rmsnorm

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] Support bf16+fp8 cudagraph (NVIDIA#2098)

* support bf16+fp8 model

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

---------

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Dropout with 8-bit RNG (NVIDIA#2014)

* Add dropout kernel with 8-bit RNG

Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license

Signed-off-by: Tim Moon <[email protected]>

* Avoid ambiguous types

Signed-off-by: Tim Moon <[email protected]>

* Do not enforce dropout prob is representable in 8 bits

Signed-off-by: Tim Moon <[email protected]>

* Expand error message

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.

Signed-off-by: Tim Moon <[email protected]>
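The less-equal versus less-than fix mentioned in this commit can be illustrated with a short count over all byte values. This is a sketch under an assumption: that an element is dropped when an 8-bit random value compares below a threshold derived from the dropout probability. The function name and the threshold computation are hypothetical and are not taken from the kernel.

```python
def drop_count(threshold, inclusive):
    """Count how many of the 256 possible byte values trigger a drop."""
    if inclusive:
        return sum(1 for r in range(256) if r <= threshold)
    return sum(1 for r in range(256) if r < threshold)


p = 0.25
threshold = int(p * 256)  # 64 for p = 0.25

# Strict comparison drops exactly threshold values out of 256.
rate_less_than = drop_count(threshold, inclusive=False) / 256

# Inclusive comparison drops one extra value, biasing the rate upward by 1/256.
rate_less_equal = drop_count(threshold, inclusive=True) / 256

print(rate_less_than, rate_less_equal)
```

With only 256 possible random values, an off-by-one in the comparison shifts the effective dropout rate by 1/256, which is small but systematic.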

* Fix linter warning

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Create GPU reload buffers on main stream (NVIDIA#2131)

* Create GPU reload buffers on main stream

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <[email protected]>

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Paweł Gadziński <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

missed a quant code removal

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

minor bug fix

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

minor code cleanup

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

minor cosmetics

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

Address review comment

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

minor comment update

Signed-off-by: Varun Thumbe <[email protected]>

Fix CI failures for UB overlap changes (NVIDIA#2149)

Signed-off-by: djns99 <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

minor bug: quantizer should not be none for unfused quantization

Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (NVIDIA#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the skip message

Signed-off-by: Kshitij Lakhani <[email protected]>

* Assert in fused attn bwd pass for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>

Add check for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to get all devs in the process for jax

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Code clean up

Signed-off-by: Kshitij Lakhani <[email protected]>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion

Signed-off-by: Kshitij Lakhani <[email protected]>

* Represent attn bias using enum instead of string

Signed-off-by: Kshitij Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

fix linting error

Signed-off-by: Varun Thumbe <[email protected]>
…d kernels needs to be fixed

Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

Add cuBLASMp-backed GEMM-like API to TE common (NVIDIA#1824)

* Pick up cuBLASMp during build

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Change lib order to fix link error

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Context creation, incomplete...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Test fixture

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* A sanity AgGemm test, failing...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix axes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Take care of uneven distribution

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use MPI to get position of local matrices

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Refactor

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Refactor & fixes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Gemm-RS

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Gemm-AR, not working...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fixes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Setting all-reduce epilogue for gemm-ar

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use supported shapes for GEMM-AR

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak tolerance

Signed-off-by: Vladimir Cherepanov <[email protected]>

* First shot at fp8

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use TensorHolder in tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More test configs

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Support comm_sm_count

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Parametrize dtypes for A, B and D separately

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak scaling

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Amax ptr

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Flags parity with cublas_gemm, saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Cleanup

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Bias tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix bias test

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Aux, saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* aux_ld

Signed-off-by: Vladimir Cherepanov <[email protected]>

* A fix

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use test::Tensor

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Set scale inv

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove unsupported test configs

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Replace libcal with NCCL

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add NVTX markers to API functions

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak GemmAr tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More test config

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix merge fallout

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove MPI dependency, comment API, add algo parameter

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix nvshmem dependency

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix nvshmem build

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Exclude CommGemm tests from L0_cppunittest

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add cpp_distributed sh file for CI

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Adapt tp TensorAllocator

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Oversubscribe is needed on some clusters

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix incomplete libcal removal

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Move CI tests to L1

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Rename context to include NVTE prefix

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove leftover code

Signed-off-by: Vladimir Cherepanov <[email protected]>

* NVTE_WITH_CUBLASMP off by default

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More detailed NVTE_CHECK diag

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Include stdbool header for legacy C compilers

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove now unused argument

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Abstract away cuBLASMp algo behind our own enum

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More detailed shape diag messages

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add license

Signed-off-by: Vladimir Cherepanov <[email protected]>

---------

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (NVIDIA#2119)

* add noop to comp amax

Signed-off-by: zhongboz <[email protected]>

* fix for fp8 blockwise recipe

Signed-off-by: zhongboz <[email protected]>

* resolve comments

Signed-off-by: zhongboz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] fix cross entropy vanishing gradients (NVIDIA#2139)

* fix cross entropy

Signed-off-by: Casper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Casper <[email protected]>

* fix comments

Signed-off-by: Casper <[email protected]>

* fix: few more style issues

Signed-off-by: Casper <[email protected]>

* fix: remove grad_output_stride (unnecessary)

Signed-off-by: Casper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: only backward was broken

Signed-off-by: Casper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Generalize cross entropy backward kernel to handle reduced and unreduced loss

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Casper <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Fix bug when enabling --overlap-grad-reduce in mcore (NVIDIA#2142)

* fix bugs when enabling --overlap-grad-reduce in mcore

Signed-off-by: Hongbin Liu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI

Signed-off-by: Hongbin Liu <[email protected]>

* format

Signed-off-by: Hongbin Liu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Hongbin Liu <[email protected]>
Co-authored-by: Hongbin Liu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Fix CUDA version in setup.py (NVIDIA#2132)

* Fix CUDA version in setup.py

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Re-enable building comm-gemm tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* WAR for nvidia-nvshmem package

Signed-off-by: Vladimir Cherepanov <[email protected]>

---------

Signed-off-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] NoScaleTensor wrapper for non-quantized data (NVIDIA#2136)

* Custom call tests passing

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix test_layer.py

Signed-off-by: Jeremy Berchtold <[email protected]>

* Lint

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix comments

Signed-off-by: Jeremy Berchtold <[email protected]>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix shardy issue with amax being shape 1,1,1 instead of shape (1,)

Signed-off-by: Jeremy Berchtold <[email protected]>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp

Signed-off-by: Jeremy Berchtold <[email protected]>

* Cast non-quantized kernels to input dtype in VJPs

Signed-off-by: Jeremy Berchtold <[email protected]>

* Rename HighPrecisionTensor to NoScaleTensor

Signed-off-by: Jeremy Berchtold <[email protected]>

* Use NoScaleTensor in pure JAX impls where it was missing

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix tests

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Fix GroupedScaledTensor creation with keyword arg (NVIDIA#2154)

Fix GroupedScaledTensor creation

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Fixing few issues with multi-process launching. (NVIDIA#2155)

* Fixing few issues with multi-process launching.

Signed-off-by: Ming Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Update list of authorized CI users (NVIDIA#2152)

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

a bit of cleanup

Signed-off-by: Varun Thumbe <[email protected]>

merge conflict

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (NVIDIA#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.

Signed-off-by: Ming Huang <[email protected]>

* Slightly refactor

Signed-off-by: Ming Huang <[email protected]>

* Adding documents of new args.

Signed-off-by: Ming Huang <[email protected]>

* Adding unit-tests.

Signed-off-by: Ming Huang <[email protected]>

* Adding license.

Signed-off-by: Ming Huang <[email protected]>

* Move unit-tests to L1.

Signed-off-by: Ming Huang <[email protected]>

* Move quantizer store/reset into FP8 only.

Signed-off-by: Ming Huang <[email protected]>

* Adding all layout support for Blackwell+

Signed-off-by: Ming Huang <[email protected]>

* Adopt the feedback from code-review.

Signed-off-by: Ming Huang <[email protected]>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>

[JAX] Delay MeshResource validation until first usage (NVIDIA#2124)

Delay MeshResource validation until first usage

Signed-off-by: Jeremy Berchtold <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (NVIDIA#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (NVIDIA#2118)

* add amax input to DBiasQuantizePrimitive and FFI

Signed-off-by: Phuong Nguyen <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero

Signed-off-by: Phuong Nguyen <[email protected]>

* fix sharding rule

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (NVIDIA#2121)

Signed-off-by: Kshitij Lakhani <[email protected]>

Temporarily remove comm_gemm tests (NVIDIA#2133)

Signed-off-by: Vladimir Cherepanov <[email protected]>

[PyTorch] Disable determinism for sm100 (NVIDIA#2130)

* disable determinism for sm100+ and cudnn<9.14

Signed-off-by: Charlene Yang <[email protected]>

* fix remaining CI failures

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes

Signed-off-by: Charlene Yang <[email protected]>

* revert more changes

Signed-off-by: Charlene Yang <[email protected]>

* remove sm100 from determinism table

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch] ONNX export of FP8 Current Scaling (NVIDIA#2068)

* Compute amax in normalization forward in current scaling in untuned kernels

Signed-off-by: Jan Bielak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* code drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* apply Tim's suggestions

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: Jan Bielak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (NVIDIA#2134)

use torch empty for empty shape instead of from_blob

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>

build: pull cached wheels (NVIDIA#2127)

* build: pull cached wheels

Signed-off-by: oliver könig <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update setup.py

Signed-off-by: oliver könig <[email protected]>

---------

Signed-off-by: oliver könig <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>

[Common] Add checks to CUDA kernel launch and CUDA API calls (NVIDIA#2074)

* add checks to cuda kernel launch and cuda API calls

Signed-off-by: Xin Yao <[email protected]>

* Remove exceptions from destructors

Signed-off-by: Tim Moon <[email protected]>

* fix weird dispatch in ln/rmsnorm

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>

[PyTorch] Support bf16+fp8 cudagraph (NVIDIA#2098)

* support bf16+fp8 model

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

---------

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>

Dropout with 8-bit RNG (NVIDIA#2014)

* Add dropout kernel with 8-bit RNG

Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license

Signed-off-by: Tim Moon <[email protected]>

* Avoid ambiguous types

Signed-off-by: Tim Moon <[email protected]>

* Do not enforce dropout prob is representable in 8 bits

Signed-off-by: Tim Moon <[email protected]>

* Expand error message

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warning

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Create GPU reload buffers on main stream (NVIDIA#2131)

* Create GPU reload buffers on main stream

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <[email protected]>

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Paweł Gadziński <[email protected]>

Fix CI failures for UB overlap changes (NVIDIA#2149)

Signed-off-by: djns99 <[email protected]>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (NVIDIA#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the skip message

Signed-off-by: Kshitij Lakhani <[email protected]>

* Assert in fused attn bwd pass for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>

Add check for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to get all devs in the process for jax

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Code clean up

Signed-off-by: Kshitij Lakhani <[email protected]>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion

Signed-off-by: Kshitij Lakhani <[email protected]>

* Represent attn bias using enum instead of string

Signed-off-by: Kshitij Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (NVIDIA#2119)

* add noop to comp amax

Signed-off-by: zhongboz <[email protected]>

* fix for fp8 blockwise recipe

Signed-off-by: zhongboz <[email protected]>

* resolve comments

Signed-off-by: zhongboz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>

[PyTorch] fix cross entropy vanishing gradients (NVIDIA#2139)

* fix cross entropy

Signed-off-by: Casper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Casper <[email protected]>

* fix comments

Signed-off-by: Casper <[email protected]>

* fix: few more style issues

Signed-off-by: Casper <[email protected]>

* fix: remove grad_output_stride (unnecessary)

Signed-off-by: Casper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: only backward was broken

Signed-off-by: Casper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Generalize cross entropy backward kernel to handle reduced and unreduced loss

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Casper <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>

Fix bug when enabling --overlap-grad-reduce in mcore (NVIDIA#2142)

* fix bugs when enabling --overlap-grad-reduce in mcore

Signed-off-by: Hongbin Liu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI

Signed-off-by: Hongbin Liu <[email protected]>

* format

Signed-off-by: Hongbin Liu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Hongbin Liu <[email protected]>
Co-authored-by: Hongbin Liu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Fix CUDA version in setup.py (NVIDIA#2132)

* Fix CUDA version in setup.py

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Re-enable building comm-gemm tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* WAR for nvidia-nvshmem package

Signed-off-by: Vladimir Cherepanov <[email protected]>

---------

Signed-off-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: Tim Moon <[email protected]>

[JAX] NoScaleTensor wrapper for non-quantized data (NVIDIA#2136)

* Custom call tests passing

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix test_layer.py

Signed-off-by: Jeremy Berchtold <[email protected]>

* Lint

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix comments

Signed-off-by: Jeremy Berchtold <[email protected]>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix shardy issue with amax being shape 1,1,1 instead of shape (1,)

Signed-off-by: Jeremy Berchtold <[email protected]>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp

Signed-off-by: Jeremy Berchtold <[email protected]>

* Cast non-quantized kernels to input dtype in VJPs

Signed-off-by: Jeremy Berchtold <[email protected]>

* Rename HighPrecisionTensor to NoScaleTensor

Signed-off-by: Jeremy Berchtold <[email protected]>

* Use NoScaleTensor in pure JAX impls where it was missing

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix tests

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>

[JAX] Fix GroupedScaledTensor creation with keyword arg (NVIDIA#2154)

Fix GroupedScaledTensor creation

Signed-off-by: Phuong Nguyen <[email protected]>

Fixing few issues with multi-process launching. (NVIDIA#2155)

* Fixing few issues with multi-process launching.

Signed-off-by: Ming Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <[email protected]>

Update list of authorized CI users (NVIDIA#2152)

Signed-off-by: Tim Moon <[email protected]>

Fused RoPE with combined QKV input. (NVIDIA#2122)

* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df100.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e.

Cleanup.

Minor cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Optimize kernels

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Misc. Cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Optimize kernel performance

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Move fused_qkv_rope test to test_fused_rope.py

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* apply shared memory optimization to separate fused rope kernels

Signed-off-by: Xin Yao <[email protected]>

* fix lint

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: vthumbe1503 <[email protected]>
Signed-off-by: vthumbe1503 <[email protected]>
Signed-off-by: vthumbe1503 <[email protected]>

Add bf16/fp32 token-per-expert to the MoE aux loss kernel (NVIDIA#2162)

* add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion

Signed-off-by: tongliu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: tongliu <[email protected]>
Co-authored-by: tongliu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[JAX] Scale swizzling via JAX transpose op (NVIDIA#2163)

* add swizzle in jax

Signed-off-by: Phuong Nguyen <[email protected]>

* added outer_impl

Signed-off-by: Phuong Nguyen <[email protected]>

* clean up FFI

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>

Extract cpp distributed tests into a separate project (NVIDIA#2165)

* Extract cpp distributed tests into a separate project

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove obsolete exclusion

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Run L1_cpp_distributed tests if at least 4 GPUs

Signed-off-by: Vladimir Cherepanov <[email protected]>

---------

Signed-off-by: Vladimir Cherepanov <[email protected]>

Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (NVIDIA#2129)

* test - adds unit test for cp utilities and the utilities

Signed-off-by: Jonathan Mitchell <[email protected]>

* assert line change

Signed-off-by: Jonathan Mitchell <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jonathan Mitchell <[email protected]>
Co-authored-by: Jonathan Mitchell <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sudhakar Singh <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch Debug] Fix issue with negative underflow% stat. (NVIDIA#2107)

* fix underflows log issue

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…d swiglu parameter correctly

Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

Lower precision gated-act to accelerate FP8 current-scaling. (#2153)

* Applying the original precision as Norm outputs' and activation computations.

Signed-off-by: Ming Huang <[email protected]>

* Adding knob to control norm output precision.

Signed-off-by: Ming Huang <[email protected]>

* Removing the knob and applying lower-precision norm with current-scaling only.

Signed-off-by: Ming Huang <[email protected]>

* Fix the error when quantizer==None

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>

[PyTorch] Support activation CPU offloading in fusible ops (#2158)

* Add CPU offloading logic to ops. Fix test to compute dgrad.

Signed-off-by: Tim Moon <[email protected]>

* Make sure grads are contiguous in op backwards

Signed-off-by: Tim Moon <[email protected]>

* Add op-based MLP to CPU offloading tests

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Handle different weight cache behavior on Hopper/Blackwell

Add MXFP8 to CPU offload tests.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove MXFP8 test

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174)

* Do not use norm fwd + amax fusion if cudnn backend is requested

Signed-off-by: Jan Bielak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Read environment variable directly to avoid include error

Signed-off-by: Jan Bielak <[email protected]>

---------

Signed-off-by: Jan Bielak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Fix unjoined comm stream in UB communicator (#2160)

Signed-off-by: djns99 <[email protected]>

FP8 Output Quantization for GEMM (#2123)

* Test working as I think it should work

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

* revert accidental change

Signed-off-by: Varun Thumbe <[email protected]>

Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

fix merge conflict

Signed-off-by: Varun Thumbe <[email protected]>

bug: missed a } in the code

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Change lib order to fix link error

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Context creation, incomplete...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Test fixture

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* A sanity AgGemm test, failing...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix axes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Take care of uneven distribution

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use MPI to get position of local matrices

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Refactor

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Refactor & fixes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Gemm-RS

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Gemm-AR, not working...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fixes

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Setting all-reduce epilogue for gemm-ar

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use supported shapes for GEMM-AR

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak tolerance

Signed-off-by: Vladimir Cherepanov <[email protected]>

* First shot at fp8

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use TensorHolder in tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More test configs

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Support comm_sm_count

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Parametrize dtypes for A, B and D separately

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak scaling

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Amax ptr

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Flags parity with cublas_gemm, saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Cleanup

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Bias tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix bias test

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Aux, saving...

Signed-off-by: Vladimir Cherepanov <[email protected]>

* aux_ld

Signed-off-by: Vladimir Cherepanov <[email protected]>

* A fix

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Use test::Tensor

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Set scale inv

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove unsupported test configs

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Replace libcal with NCCL

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add NVTX markers to API functions

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Tweak GemmAr tests

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More test config

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix merge fallout

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove MPI dependency, comment API, add algo parameter

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix nvshmem dependency

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix nvshmem build

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Exclude CommGemm tests from L0_cppunittest

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add cpp_distributed sh file for CI

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Adapt to TensorAllocator

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Oversubscribe is needed on some clusters

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Fix incomplete libcal removal

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Move CI tests to L1

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Rename context to include NVTE prefix

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove leftover code

Signed-off-by: Vladimir Cherepanov <[email protected]>

* NVTE_WITH_CUBLASMP off by default

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More detailed NVTE_CHECK diag

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Include stdbool header for legacy C compilers

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Remove now unused argument

Signed-off-by: Vladimir Cherepanov <[email protected]>

* Abstract away cuBLASMp algo behind our own enum

Signed-off-by: Vladimir Cherepanov <[email protected]>

* More detailed shape diag messages

Signed-off-by: Vladimir Cherepanov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Vladimir Cherepanov <[email protected]>

* Add license

Signed-off-by: Vladimir Cherepanov <[email protected]>

---------

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: Vladimir Cherepanov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.

Signed-off-by: Ming Huang <[email protected]>
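
As a rough illustration of the workflow in item 3 above, the following pure-Python sketch simulates the steps on two hypothetical ranks. It is not the Transformer Engine implementation: the helper names, the list-based stand-ins for the collectives, and the omission of real FP8 rounding are all assumptions made for illustration. Only the E4M3 maximum of 448 is a property of the format itself. The GroupedGEMM step is left out.

```python
# Hypothetical sketch of: AR-max -> FP8 quant -> FP8 AG (GroupedGEMM omitted).
# Collectives are simulated with plain lists; no real FP8 rounding is done.
FP8_E4M3_MAX = 448.0  # largest finite magnitude in the E4M3 format


def all_reduce_max(local_values):
    """Simulated all-reduce(max): every rank ends up with the global maximum."""
    return max(local_values)


def quantize_with_amax(shard, amax):
    """Current scaling: derive the scale from a given amax, then scale and clamp."""
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    scaled = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x * scale)) for x in shard]
    return scaled, 1.0 / scale  # scaled data plus scale_inv for dequantization


def all_gather(shards):
    """Simulated all-gather: concatenate the shards from every rank."""
    return [x for shard in shards for x in shard]


# Two hypothetical ranks, each holding one shard of the activation.
rank_shards = [[0.5, -2.0, 1.0], [4.0, -0.25, 3.0]]

# Step 1: each rank computes a local amax, then all ranks agree on the maximum.
global_amax = all_reduce_max([max(abs(x) for x in s) for s in rank_shards])
# Step 2: every rank quantizes its own shard with the shared amax.
quantized = [quantize_with_amax(s, global_amax) for s in rank_shards]
# Step 3: the already-scaled shards are gathered.
gathered = all_gather([q for q, _ in quantized])
scale_inv = quantized[0][1]  # identical on every rank because amax is shared
```

In this sketch, reducing the amax first means every shard is scaled identically, so the gathered data can be dequantized with a single `scale_inv`.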

* Slightly refactor

Signed-off-by: Ming Huang <[email protected]>

* Adding documents of new args.

Signed-off-by: Ming Huang <[email protected]>

* Adding unit-tests.

Signed-off-by: Ming Huang <[email protected]>

* Adding license.

Signed-off-by: Ming Huang <[email protected]>

* Move unit-tests to L1.

Signed-off-by: Ming Huang <[email protected]>

* Move quantizer store/reset into FP8 only.

Signed-off-by: Ming Huang <[email protected]>

* Adding all layout support for Blackwell+

Signed-off-by: Ming Huang <[email protected]>

* Adopt the feedback from code-review.

Signed-off-by: Ming Huang <[email protected]>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage

Signed-off-by: Jeremy Berchtold <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode

Signed-off-by: Jeremy Berchtold <[email protected]>

* Expose global QuantizeConfig instance as a getter

Signed-off-by: Jeremy Berchtold <[email protected]>

* Format and lint

Signed-off-by: Jeremy Berchtold <[email protected]>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling

Signed-off-by: Jeremy Berchtold <[email protected]>

* Rename UsageType to TensorSource

Signed-off-by: Jeremy Berchtold <[email protected]>

* Update test_layer.py

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: jberchtold-nvidia <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI

Signed-off-by: Phuong Nguyen <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero

Signed-off-by: Phuong Nguyen <[email protected]>

* fix sharding rule

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)

Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Temporarily remove comm_gemm tests (#2133)

Signed-off-by: Vladimir Cherepanov <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14

Signed-off-by: Charlene Yang <[email protected]>

* fix remaining CI failures

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes

Signed-off-by: Charlene Yang <[email protected]>

* revert more changes

Signed-off-by: Charlene Yang <[email protected]>

* remove sm100 from determinism table

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels

Signed-off-by: Jan Bielak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* code drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* apply Tim's suggestions

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: Jan Bielak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

build: pull cached wheels (#2127)

* build: pull cached wheels

Signed-off-by: oliver könig <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update setup.py

Signed-off-by: oliver könig <[email protected]>

---------

Signed-off-by: oliver könig <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

feat: Add support for multiple quantization modes in the UB communicators (#2043)

Signed-off-by: Varun Thumbe <[email protected]>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls

Signed-off-by: Xin Yao <[email protected]>

* Remove exceptions from destructors

Signed-off-by: Tim Moon <[email protected]>

* fix weird dispatch in ln/rmsnorm

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Robin Zhang <[email protected]>

---------

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG

Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license

Signed-off-by: Tim Moon <[email protected]>

* Avoid ambiguous types

Signed-off-by: Tim Moon <[email protected]>

* Do not enforce dropout prob is representable in 8 bits

Signed-off-by: Tim Moon <[email protected]>

* Expand error message

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.

Signed-off-by: Tim Moon <[email protected]>
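
The less-equal versus less-than point above can be checked exhaustively for an 8-bit random value. The sketch below is a hypothetical pure-Python model, not the kernel from this PR: it assumes the drop decision compares a uniform random byte against a threshold of `round(p_drop * 256)`, which is an assumption about the scheme rather than something stated in the commit.

```python
def drop_fraction(p_drop: float, inclusive: bool) -> float:
    """Fraction of the 256 possible random bytes that lead to a dropped element.

    Hypothetical model: an element is dropped when its random byte r compares
    below the threshold, using either `r <= threshold` or `r < threshold`.
    """
    threshold = int(round(p_drop * 256))
    dropped = 0
    for r in range(256):  # enumerate every possible 8-bit random value
        hit = (r <= threshold) if inclusive else (r < threshold)
        if hit:
            dropped += 1
    return dropped / 256


less_than = drop_fraction(0.25, inclusive=False)
less_equal = drop_fraction(0.25, inclusive=True)
print(f"r <  threshold drops {less_than:.6f} of elements")
print(f"r <= threshold drops {less_equal:.6f} of elements")
```

Under this model the strict comparison drops exactly `threshold / 256` of the elements, while the inclusive one drops `(threshold + 1) / 256`, a bias of 1/256 toward dropping.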

* Fix linter warning

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <[email protected]>

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <[email protected]>

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Paweł Gadziński <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

missed a quant code removal

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

minor bug fix

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

minor code cleanup

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

minor cosmetics

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Address review comment

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

minor comment update

Signed-off-by: Varun Thumbe <[email protected]>

Fix CI failures for UB overlap changes (#2149)

Signed-off-by: djns99 <[email protected]>

minor bug: quantizer should not be none for unfused quantization

Signed-off-by: Varun Thumbe <[email protected]>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the skip message

Signed-off-by: Kshitij Lakhani <[email protected]>

* Assert in fused attn bwd pass for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>

Add check for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to get all devs in the process for jax

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Code clean up

Signed-off-by: Kshitij Lakhani <[email protected]>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion

Signed-off-by: Kshitij Lakhani <[email protected]>

* Represent attn bias using enum instead of string

Signed-off-by: Kshitij Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

fix linting error

Signed-off-by: Varun Thumbe <[email protected]>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax

Signed-off-by: zhongboz <[email protected]>

* fix for fp8 blockwise recipe

Signed-off-by: zhongboz <[email protected]>

* resolve comments

Signed-off-by: zhongboz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: zhongboz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>

address review comments

Signed-off-by: Varun Thumbe <[email protected]>

* Update test_multi_process_distributed_grouped_gemm.py

change accidentally added while merging

Signed-off-by: vthumbe1503 <[email protected]>

* Update dense.py

change accidentally added while merging

Signed-off-by: vthumbe1503 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address revie comments

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bug solved: delayed scaling quantization with mxfp8 inputs didnt work

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the unit test error

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* just to trigger ci

Signed-off-by: Varun Thumbe <[email protected]>

* address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

* fix merge conflict

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: vthumbe1503 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

TE Gemma tutorial attempt#2 (#1839)

* add tutorial files and other local changes

Signed-off-by: Sudhakar Singh <[email protected]>

* remove extraneous code for easy debu

Signed-off-by: Sudhakar Singh <[email protected]>

* make cuda graphs work with non-paged and paged attention

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* perf imp for kv cache ops

Signed-off-by: Sudhakar Singh <[email protected]>

* add code for calibration

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* optimize kv_cache reindex and copy kernels

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changes to make quantizers work with fp8_calibration

Signed-off-by: Sudhakar Singh <[email protected]>

* avoid reindexing from python side

Signed-off-by: Charlene Yang <[email protected]>

* rename variable from previous commit

Signed-off-by: Charlene Yang <[email protected]>

* minor fix

Signed-off-by: Charlene Yang <[email protected]>

* minor fix

Signed-off-by: Charlene Yang <[email protected]>

* use quantizer only if needed

Signed-off-by: Sudhakar Singh <[email protected]>

* functionality of the tutorial tested and perf checked

Signed-off-by: Sudhakar Singh <[email protected]>

* remove files and update headers/licenses

Signed-off-by: Sudhakar Singh <[email protected]>

* update header/license

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tutorial for review

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make weights downloadable on the fly; remove extra print statements

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint and update comments

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comma back, typo

Signed-off-by: Sudhakar Singh <[email protected]>

* sequence_start_positions should be None for training

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add paged attention numberes and update requirements.txt file

Signed-off-by: Sudhakar Singh <[email protected]>

* more fixes

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make tutorial work on blackwell

Signed-off-by: Sudhakar Singh <[email protected]>

* remove gemma FT tutorial for now

Signed-off-by: Sudhakar Singh <[email protected]>

* fixing the headings placement and rewording attention -> kv caching

Signed-off-by: Sudhakar Singh <[email protected]>

* fixes from comments

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the images

Signed-off-by: Sudhakar Singh <[email protected]>

* misc fixes

Signed-off-by: Sudhakar Singh <[email protected]>

* add more comments to te_gemma.py and cleanup utils.py

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more information about the hierarchy of the classes used in the tutorial

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add better cuda graphs picture

Signed-off-by: Sudhakar Singh <[email protected]>

* addd updated cuda graphs pictures

Signed-off-by: Sudhakar Singh <[email protected]>

* add illustrated cuda graphs

Signed-off-by: Sudhakar Singh <[email protected]>

* fix

Signed-off-by: Sudhakar Singh <[email protected]>

* small fixes in documentation

Signed-off-by: Sudhakar Singh <[email protected]>

* add torch.no_grad() to force reduced memory usage

Signed-off-by: Sudhakar Singh <[email protected]>

* some fixes from recent comments

Signed-off-by: Sudhakar Singh <[email protected]>

* more fixes from remaining comments

Signed-off-by: Sudhakar Singh <[email protected]>

* add te_rope_emb to class desc

Signed-off-by: Sudhakar Singh <[email protected]>

* fix tutorial wording; add calibration fix to grouped_linear.py

Signed-off-by: Sudhakar Singh <[email protected]>

---------

Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <[email protected]>

Fix memory overhead of linear layer when all gather from sequence parallel (#2125)

* fix memory overhead of all gather from sequence parallel

Signed-off-by: Yuzhong Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

Signed-off-by: Tim Moon <[email protected]>

* quick fix the errors when for UB buffers

Signed-off-by: Yuzhong Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/module/linear.py

Signed-off-by: Tim Moon <[email protected]>

* Avoid deallocating FP8 scale-invs since they are reused

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Yuzhong Wang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>

Fix incorrect TP rank calculation when using data parallel (#2179)

Signed-off-by: djns99 <[email protected]>

[Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045)

* feat: add cutlass group gemm support

Signed-off-by: Min Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor: refactor multi tensor gemm interface

Signed-off-by: Min Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor: refactor nvte_multi_stream_cublas_gemm func and add license info

Signed-off-by: Min Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: add unit test for cutlass group gemm

Signed-off-by: Min Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: add cutlass support type protect

Signed-off-by: Min Yang <[email protected]>

* add tests and fix lint

Signed-off-by: Xin Yao <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: fix unit tests error

Signed-off-by: Min Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: refactor host workspace malloc

Signed-off-by: Min Yang <[email protected]>

* update cutlass

Signed-off-by: Xin Yao <[email protected]>

* update cutlass

Signed-off-by: Xin Yao <[email protected]>

* further relex threshold and add a env var to warn fall back

Signed-off-by: Xin Yao <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Min Yang <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: alan yang <[email protected]>
Co-authored-by: Min Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>

[PyTorch] Support FA3 for MLA and with CP (#1907)

feature(FA3,MLA,CP):
1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward
2. Update get_attention_backend method because FA3 support MLA now
3. Add CP MLA support for FA3
4. Add unit tests for FA3 MLA CP
5. Update attention doc

Signed-off-by: zhujian <[email protected]>

Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185)

* Fix cudnn version checks for kv cache for sm89. Add cudnn version check in preparation for 9.14 when getting backend

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor fix for cuDNN version condition check

Signed-off-by: Kshitij Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@jberchtold-nvidia
Collaborator

/te-ci L2 jax

Collaborator

@jberchtold-nvidia left a comment

Overall looks good pending CI, left some comments and questions, mostly minor things. Thanks for updating the Flax modules as well, great work on this PR!

@jberchtold-nvidia
Collaborator

/te-ci L2 jax

@phu0ngng changed the title from "Clamped Swiglu Integration to JAX" to "[JAX] Clamped Swiglu Integration" on Sep 23, 2025
Collaborator

@phu0ngng left a comment

Hi,
I left some minor comments. LGTM otherwise. Thanks!



@dataclass(frozen=True)
class ClampedSwigluParams:
Collaborator

Hi, let's move this class to cpp_extensions.py and expose it as a part of the tex.

Collaborator

He had it there originally and I told him to move it here. Since this param will be user-facing, users would need to create it themselves and import it from cpp_extensions. My understanding is that we wanted to encapsulate and hide those internals from users of the higher-level VJPs or Flax layers, so I suggested moving it here.

But do you think it's more important to keep cpp_extensions from depending on a type defined in the higher-level VJP, even though users will then need to use cpp_extensions types directly? Let me know what you think, thanks!

Collaborator

The reasons I suggested moving it to the tex are that:

  • Users don't really use this activation VJP, so they will still need to import from this file even if they don't use other stuff.
  • We have other Params/Configs that are placed in the tex (e.g., AmaxScope is in cpp_extensions/quantization.py, CollectiveOp is in cpp_extensions/gemm.py).

At the same time, I agree that exposing it here may be a cleaner approach. If we decide to keep it here, perhaps we should start follow-up work on which functions/symbols we want to export from each VJP file so that users can do import transformer_engine.jax as te_jax.

Collaborator

Sounds good, I'm okay with reverting back to having it in cpp_extensions (tex) for now then to be consistent with other structures until we decide on a way to organize these structures. Apologies for the churn, Varun

Collaborator Author

No issues @jberchtold-nvidia


- def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
+ def _jax_act_lu(
+     inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None
Collaborator

Hi, should we make a default ActParam instead of None here?

Collaborator Author

I am actually following the generally safer practice of not providing default values other than None for user-defined data types in function definitions. Although the ActivationParams object is immutable today, this could become dangerous if it ever stops being immutable: default values are instantiated once, at import time, and then shared across calls, so an accidental change to the param in one function could affect another, which would be hard to debug. This could also be solved with a default factory, and such a scenario is very unrealistic, so let me know what you think.
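
The concern above is the usual Python default-argument pitfall. As a minimal sketch only (the default values 7.0 and 1.702 and the helper `resolve_act_params` are illustrative assumptions, not code from this PR), the following contrasts a default that is evaluated once at definition time with a `None` default that is resolved on each call:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class ClampedSwigluParams:
    # Field names mirror the struct quoted in this thread; the default
    # values are placeholders for illustration only.
    limit: float = 7.0
    alpha: float = 1.702


def shared_default(history: List[int] = []) -> List[int]:
    # Pitfall: the list is created once, when the function is defined,
    # and every call that omits `history` mutates that same object.
    history.append(len(history))
    return history


def resolve_act_params(
    act_params: Optional[ClampedSwigluParams] = None,
) -> ClampedSwigluParams:
    # Hypothetical helper: with a None default, a fresh object is built
    # on each call, so no state can leak between callers.
    if act_params is None:
        act_params = ClampedSwigluParams()
    return act_params
```

With a frozen dataclass the shared-default variant would be harmless in practice, which is why the comment describes the risk as unrealistic; the `None` pattern simply stays safe if the type ever becomes mutable.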

Comment on lines +41 to +44
struct ClampedSwigluConfig {
  float limit;
  float alpha;
};
Collaborator

Hi,
Since we will lower this struct for all activations (even if we don't use it), should we have a more general name for it instead?
Something like:

struct ActParams {
  float swiglu_limit;
  float swiglu_alpha;
};

Collaborator

Varun made a different version here with a generic outer struct and a nested inner struct for ClampedSwiglu, which I believe is what you have in mind. I haven't reviewed this other version yet, but I'm planning to do so today:
vthumbe1503#2

Collaborator Author

I have moved the generic approach into this PR itself.
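
For readers following the thread, the "generic outer struct with a nested inner struct" idea can be sketched on the Python side roughly as below. This is an assumption-laden illustration: the class name `ActivationParams` appears in this conversation, but its fields, the placeholder defaults, and the helper `flatten_params` are hypothetical and may not match the code that was merged.

```python
from dataclasses import asdict, dataclass, field


@dataclass(frozen=True)
class ClampedSwigluParams:
    # Placeholder defaults, for illustration only.
    limit: float = 7.0
    alpha: float = 1.702


@dataclass(frozen=True)
class ActivationParams:
    # Generic container: one nested field per parameterized activation,
    # so new activations can be added without changing call signatures.
    clamped_swiglu: ClampedSwigluParams = field(default_factory=ClampedSwigluParams)


def flatten_params(params: ActivationParams) -> dict:
    # Hypothetical helper: flatten the nested fields into "outer.inner"
    # keys, e.g. to inspect what would be handed to a lower layer.
    flat = {}
    for outer, inner in asdict(params).items():
        for name, value in inner.items():
            flat[f"{outer}.{name}"] = value
    return flat
```

Because both dataclasses are frozen (and keep the default `eq=True`), instances are hashable, which is convenient when such an object has to be treated as a static, non-traced argument.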

@vthumbe1503 requested a review from Copilot on September 23, 2025 23:09

@Copilot Copilot AI left a comment

Pull Request Overview

This PR integrates Clamped SwiGLU activation into the JAX backend, adding support for the GPT-OSS variant of SwiGLU that includes parameter-based clamping and scaling functionality.

  • Adds ClampedSwiGLU activation implementation across PyTorch, JAX, and C++ backends
  • Implements parameter passing infrastructure for activation functions requiring additional parameters
  • Extends existing activation framework to support parameterized activations
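
As a rough scalar reference for what "clamping and scaling" means here, the sketch below follows the clamped SwiGLU formulation published with the GPT-OSS reference model, as far as I understand it: the gate is clamped from above, the linear branch is clamped symmetrically, and the linear branch is offset by 1. The default values of `limit` and `alpha` are assumptions taken from that reference, and the kernels in this PR should be checked against transformer_engine/common/util/math.h rather than against this snippet.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |z|).
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def clamped_swiglu(gate: float, linear: float,
                   limit: float = 7.0, alpha: float = 1.702) -> float:
    # Clamp the gate from above only, and the linear branch to [-limit, limit].
    gate = min(gate, limit)
    linear = max(-limit, min(linear, limit))
    # Scaled SiLU on the gate, multiplied by the offset linear branch.
    return gate * sigmoid(alpha * gate) * (linear + 1.0)
```

In a real layer the two inputs are halves of the same projected tensor and the function is applied elementwise; a scalar version is used here only to keep the arithmetic visible.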

Reviewed Changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| transformer_engine/pytorch/ops/basic/activation.py | Adds ClampedSwiGLU class with limit and alpha parameters |
| transformer_engine/pytorch/ops/basic/__init__.py | Exports ClampedSwiGLU activation |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | Exposes clamped_swiglu functions to Python |
| transformer_engine/pytorch/csrc/extensions/activation.cpp | Refactors activation helpers to support parameterized functions |
| transformer_engine/pytorch/csrc/extensions.h | Declares clamped activation function signatures |
| transformer_engine/jax/layernorm_mlp.py | Adds activation_params parameter to layernorm_mlp |
| transformer_engine/jax/flax/transformer.py | Adds mlp_activation_params to TransformerLayer |
| transformer_engine/jax/flax/module.py | Extends LayerNormMLP with activation_params support |
| transformer_engine/jax/csrc/extensions/pybind.cpp | Adds CLAMPED_SWIGLU enum value |
| transformer_engine/jax/csrc/extensions/activation.cpp | Implements clamped activation FFI handlers |
| transformer_engine/jax/csrc/extensions.h | Defines ActivationConfig structs |
| transformer_engine/jax/cpp_extensions/activation.py | Implements ActivationParams framework and clamped functions |
| transformer_engine/jax/activation.py | Updates activation functions to accept parameters |
| transformer_engine/common/util/vectorized_pointwise.h | Updates kernels for parameterized activations |
| transformer_engine/common/util/math.h | Adds clamped SiLU mathematical functions |
| transformer_engine/common/util/cast_gated_kernels.cuh | Updates gated kernels for parameter support |
| transformer_engine/common/include/transformer_engine/activation.h | Adds clamped activation API declarations |
| transformer_engine/common/activation/swiglu.cu | Implements nvte_clamped_swiglu functions |
| transformer_engine/common/activation/relu.cu | Updates function calls with parameters |
| transformer_engine/common/activation/gelu.cu | Updates function calls with parameters |
| transformer_engine/common/activation/activation_template.h | Updates templates for parameter passing |
| tests/pytorch/test_fusible_ops.py | Adds ClampedSwiGLU test cases |
| tests/jax/test_custom_call_compute.py | Updates tests for parameterized activations |

@vthumbe1503 force-pushed the gpt-oss-jax branch 6 times, most recently from e81a0e1 to 6675779, on September 24, 2025 00:47
Signed-off-by: Varun Thumbe <[email protected]>
vthumbe1503 and others added 9 commits September 23, 2025 18:41
Signed-off-by: vthumbe1503 <[email protected]>
Add documentation for quantization function parameters and return value.

Signed-off-by: vthumbe1503 <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
@vthumbe1503
Collaborator Author

/te-ci L2 jax

@vthumbe1503
Collaborator Author

/te-ci L2 jax

4 participants