Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
f8e5415
Pick up cuBLASMp during build
vcherepanov-nv Feb 5, 2025
b7a55cc
Saving...
vcherepanov-nv Feb 6, 2025
1d104e8
Change lib order to fix link error
vcherepanov-nv Mar 20, 2025
d03d95d
Saving...
vcherepanov-nv Mar 20, 2025
73d7dcc
Context creation, incomplete...
vcherepanov-nv Mar 22, 2025
93f721a
Test fixture
vcherepanov-nv Mar 26, 2025
77b2f48
Saving...
vcherepanov-nv Apr 4, 2025
0b94258
A sanity AgGemm test, failing...
vcherepanov-nv Apr 7, 2025
da9e1f5
Saving...
vcherepanov-nv Apr 7, 2025
3f78470
Fix axes
vcherepanov-nv Apr 9, 2025
72c6b5a
Take care of uneven distribution
vcherepanov-nv Apr 10, 2025
95b459f
Use MPI to get position of local matrices
vcherepanov-nv Apr 10, 2025
ba7566d
Refactor
vcherepanov-nv Apr 11, 2025
d7f0e77
Refactor & fixes
vcherepanov-nv Apr 13, 2025
5a87fac
Saving...
vcherepanov-nv Apr 14, 2025
08178de
Gemm-RS
vcherepanov-nv Apr 14, 2025
f5f1cef
Gemm-AR, not working...
vcherepanov-nv Apr 14, 2025
fb32a33
Fixes
vcherepanov-nv Apr 14, 2025
e29ca02
Setting all-reduce epilogue for gemm-ar
vcherepanov-nv Apr 14, 2025
a44b787
Use supported shapes for GEMM-AR
vcherepanov-nv Apr 15, 2025
29dfd6e
Tweak tolerance
vcherepanov-nv Apr 15, 2025
b66e2b3
First shot at fp8
vcherepanov-nv Apr 18, 2025
cf42791
Use TensorHolder in tests
vcherepanov-nv Apr 18, 2025
675a84c
More test configs
vcherepanov-nv Apr 19, 2025
1c6cf58
Support comm_sm_count
vcherepanov-nv Apr 19, 2025
39c42ff
Parametrize dtypes for A, B and D separately
vcherepanov-nv Apr 21, 2025
2e3c468
Tweak scaling
vcherepanov-nv Apr 23, 2025
51d618c
Amax ptr
vcherepanov-nv Apr 23, 2025
fa7418f
Flags parity with cublas_gemm, saving...
vcherepanov-nv Apr 24, 2025
6b5cf31
Cleanup
vcherepanov-nv Apr 25, 2025
ae3b95e
Bias tests
vcherepanov-nv Apr 25, 2025
bac5306
Fix bias test
vcherepanov-nv Apr 25, 2025
d933220
Aux, saving...
vcherepanov-nv Apr 25, 2025
25f3f64
aux_ld
vcherepanov-nv Apr 28, 2025
ae0e022
A fix
vcherepanov-nv May 1, 2025
25a3f0d
Use test::Tensor
vcherepanov-nv May 3, 2025
6680f6d
Set scale inv
vcherepanov-nv May 3, 2025
b660497
Remove unsupported test configs
vcherepanov-nv May 5, 2025
7df803d
Tweak tests
vcherepanov-nv May 5, 2025
05279e7
Replace libcal with NCCL
vcherepanov-nv May 6, 2025
08bf8b5
Add NVTX markers to API functions
vcherepanov-nv May 7, 2025
c86169b
Tweak GemmAr tests
vcherepanov-nv May 14, 2025
f7fa07f
More test config
vcherepanov-nv May 16, 2025
cbb6040
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2025
a1c673e
Fix merge fallout
vcherepanov-nv May 27, 2025
e9d0c55
Remove MPI dependency, comment API, add algo parameter
vcherepanov-nv Jun 2, 2025
2da71ac
Fix nvshmem dependency
vcherepanov-nv Jun 4, 2025
3cd3c27
Fix nvshmem build
vcherepanov-nv Jul 18, 2025
43e23f3
Exclude CommGemm tests from L0_cppunittest
vcherepanov-nv Jul 20, 2025
402d69a
Add cpp_distributed sh file for CI
vcherepanov-nv Jul 20, 2025
711d91a
Adapt to TensorAllocator
vcherepanov-nv Jul 22, 2025
76a0a55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2025
e430603
Skip GemmAr test on unsupported HW
vcherepanov-nv Jul 23, 2025
020b97c
Oversubscribe is needed on some clusters
vcherepanov-nv Jul 24, 2025
c9c63f6
Fix incomplete libcal removal
vcherepanov-nv Jul 28, 2025
eb994c4
Move CI tests to L1
vcherepanov-nv Jul 28, 2025
ed46865
Rename context to include NVTE prefix
vcherepanov-nv Aug 18, 2025
1018cac
Remove leftover code
vcherepanov-nv Aug 18, 2025
dff0827
NVTE_WITH_CUBLASMP off by default
vcherepanov-nv Aug 18, 2025
67a0294
More detailed NVTE_CHECK diag
vcherepanov-nv Aug 19, 2025
db2d304
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2025
8d50768
Comment API
vcherepanov-nv Aug 19, 2025
9c7e08d
Include stdbool header for legacy C compilers
vcherepanov-nv Aug 19, 2025
7055436
Remove now unused argument
vcherepanov-nv Aug 22, 2025
60059bb
Abstract away cuBLASMp algo behind our own enum
vcherepanov-nv Aug 22, 2025
17e7499
More detailed shape diag messages
vcherepanov-nv Aug 22, 2025
152c792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2025
50fef0a
Update transformer_engine/common/include/transformer_engine/comm_gemm.h
mk-61 Aug 25, 2025
b111180
Add license
vcherepanov-nv Aug 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qa/L0_cppunittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild .
cmake --build build
export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
ctest --test-dir build -j$NUM_PARALLEL_JOBS
ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)'
15 changes: 15 additions & 0 deletions qa/L1_cpp_distributed/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Abort on the first failing command so CI reports the real error.
set -e

# Find TE: default to the container install location unless the caller overrides it.
: ${TE_PATH:=/opt/transformerengine}
# Resolve where pip placed the package (handles both regular and editable installs;
# "tail -n 1" picks the editable location when both lines are present).
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH="$TE_LIB_PATH:$LD_LIBRARY_PATH"

# Build the C++ tests out-of-source, then run the distributed comm+GEMM test
# with 4 ranks (--oversubscribe: some CI clusters expose fewer slots than ranks).
cd "$TE_PATH/tests/cpp"
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm
13 changes: 13 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Installation script."""

from importlib import metadata
import os
import time
from pathlib import Path
Expand Down Expand Up @@ -66,6 +67,18 @@ def setup_common_extension() -> CMakeExtension:
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
    # Opt-in cuBLASMp backend. Dependency roots come from env vars when set,
    # otherwise from the file layout of the installed pip wheels.
    cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
    cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
        "nvidia-cublasmp-cu12"
    ).locate_file("nvidia/cublasmp/cu12")
    cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
    nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
        "nvidia-nvshmem-cu12"
    ).locate_file("nvidia/nvshmem")
    cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
    # Log all three flags appended above; [-2:] silently dropped the ON switch.
    print("CMAKE_FLAGS:", cmake_flags[-3:])

# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(../../transformer_engine)
include_directories(${CMAKE_SOURCE_DIR})

find_package(CUDAToolkit REQUIRED)
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)

add_subdirectory(comm_gemm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine as a quick hack, but if we add more L1 tests we should have separate build processes for L0 and L1 tests. This way the L0 tests don't need to build the L1 tests or require MPI. It also removes the need for the hacky regex exclusions in the QA script.

add_subdirectory(operator)
add_subdirectory(util)
19 changes: 19 additions & 0 deletions tests/cpp/comm_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

add_executable(test_comm_gemm
test_comm_gemm.cu
../test_common.cu)

find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include)
target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX)

include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
Loading