Add cuBLASMp-backed GEMM-like API to TE common #1824
Merged
69 commits
f8e5415 Pick up cuBLASMp during build (vcherepanov-nv)
b7a55cc Saving... (vcherepanov-nv)
1d104e8 Change lib order to fix link error (vcherepanov-nv)
d03d95d Saving... (vcherepanov-nv)
73d7dcc Context creation, incomplete... (vcherepanov-nv)
93f721a Test fixure (vcherepanov-nv)
77b2f48 Saving... (vcherepanov-nv)
0b94258 A sanity AgGemm test, failing... (vcherepanov-nv)
da9e1f5 Saving... (vcherepanov-nv)
3f78470 Fix axes (vcherepanov-nv)
72c6b5a Take care of uneven distribution (vcherepanov-nv)
95b459f Use MPI to get position of local matrices (vcherepanov-nv)
ba7566d Refactor (vcherepanov-nv)
d7f0e77 Refactor & fixes (vcherepanov-nv)
5a87fac Saving... (vcherepanov-nv)
08178de Gemm-RS (vcherepanov-nv)
f5f1cef Gemm-AR, not working... (vcherepanov-nv)
fb32a33 Fixes (vcherepanov-nv)
e29ca02 Setting all-reduce epilogue for gemm-ar (vcherepanov-nv)
a44b787 Use supported shapes for GEMM-AR (vcherepanov-nv)
29dfd6e Tweak tolerance (vcherepanov-nv)
b66e2b3 First shot at fp8 (vcherepanov-nv)
cf42791 Use TensorHolder in tests (vcherepanov-nv)
675a84c More test configs (vcherepanov-nv)
1c6cf58 Support comm_sm_count (vcherepanov-nv)
39c42ff Parametrize dtypes for A, B and D separately (vcherepanov-nv)
2e3c468 Tweak scaling (vcherepanov-nv)
51d618c Amax ptr (vcherepanov-nv)
fa7418f Flags parity with cublas_gemm, saving... (vcherepanov-nv)
6b5cf31 Cleanup (vcherepanov-nv)
ae3b95e Bias tests (vcherepanov-nv)
bac5306 Fix bias test (vcherepanov-nv)
d933220 Aux, saving... (vcherepanov-nv)
25f3f64 aux_ld (vcherepanov-nv)
ae0e022 A fix (vcherepanov-nv)
25a3f0d Use test::Tensor (vcherepanov-nv)
6680f6d Set scale inv (vcherepanov-nv)
b660497 Remove unsupported test configs (vcherepanov-nv)
7df803d Tweak tests (vcherepanov-nv)
05279e7 Replace libcal with NCCL (vcherepanov-nv)
08bf8b5 Add NVTX markers to API functions (vcherepanov-nv)
c86169b Tweak GemmAr tests (vcherepanov-nv)
f7fa07f More test config (vcherepanov-nv)
cbb6040 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a1c673e Fix merge fallout (vcherepanov-nv)
e9d0c55 Remove MPI dependency, comment API, add algo parameter (vcherepanov-nv)
2da71ac Fix nvshmem dependency (vcherepanov-nv)
3cd3c27 Fix nvshmem build (vcherepanov-nv)
43e23f3 Excluse CommGemm tests from L0_cppunittest (vcherepanov-nv)
402d69a Add cpp_distributed sh file for CI (vcherepanov-nv)
711d91a Adapt tp TensorAllocator (vcherepanov-nv)
76a0a55 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
e430603 Skip GemmAr test on unsupported HW (vcherepanov-nv)
020b97c Oversibscribe is needed on some clusters (vcherepanov-nv)
c9c63f6 Fix incomplete libcal removal (vcherepanov-nv)
eb994c4 Move CI tests to L1 (vcherepanov-nv)
ed46865 Rename context to include NVTE prefix (vcherepanov-nv)
1018cac Remove leftover code (vcherepanov-nv)
dff0827 NVTE_WITH_CUBLASMP off by default (vcherepanov-nv)
67a0294 More detailed NVTE_CHECK diag (vcherepanov-nv)
db2d304 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8d50768 Comment API (vcherepanov-nv)
9c7e08d Include stdbool header for legacy C compilers (vcherepanov-nv)
7055436 Remove now unused argument (vcherepanov-nv)
60059bb Abstract away cuBLASMp algo behind our own enum (vcherepanov-nv)
17e7499 More detailed shape diag messages (vcherepanov-nv)
152c792 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
50fef0a Update transformer_engine/common/include/transformer_engine/comm_gemm.h (mk-61)
b111180 Add license (vcherepanov-nv)
@@ -0,0 +1,15 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

cd $TE_PATH/tests/cpp
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm
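
The `TE_LIB_PATH` line in the script above depends on the order of fields printed by `pip3 show`: a regular install prints only `Location:`, while an editable install also prints a later `Editable project location:` line, and `tail -n 1` picks whichever comes last. A minimal sketch of that behavior, using canned text instead of a real `pip3` call (the sample paths and the `extract_te_lib_path` helper name are hypothetical, introduced here only for illustration):

```shell
# Hypothetical `pip3 show` output for an editable install; the paths are made up.
sample_output='Name: transformer_engine
Version: 0.0.0
Location: /usr/lib/python3/site-packages
Editable project location: /workspace/TransformerEngine'

# Same grep | tail | awk pipeline as the script, reading from stdin instead of pip3.
extract_te_lib_path() {
  grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}'
}

# With both lines present, the editable location (the last match) is selected.
te_lib_path=$(printf '%s\n' "$sample_output" | extract_te_lib_path)
echo "$te_lib_path"
```

Because `awk '{print $NF}'` keeps only the last whitespace-separated field, this approach presumably breaks for install paths that contain spaces.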
@@ -0,0 +1,19 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

add_executable(test_comm_gemm
  test_comm_gemm.cu
  ../test_common.cu)

find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)
find_library(NCCL_LIB
  NAMES nccl libnccl
  PATH_SUFFIXES lib
  REQUIRED)
target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include)
target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX)

include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
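
The CMake file above reads the cuBLASMp root from the `CUBLASMP_HOME` environment variable. In CMake, `$ENV{VAR}` expands to an empty string when the variable is unset, so the include path would silently become `/include`. A wrapper script could check for this before configuring. The sketch below is one way to do that; the `check_cublasmp_home` helper is hypothetical and not part of this PR:

```shell
# Hypothetical helper, not part of the PR: fail fast when the cuBLASMp root is unusable.
check_cublasmp_home() {
  root=$1
  # An unset or empty variable would make CMake use "/include".
  if [ -z "$root" ]; then
    echo "CUBLASMP_HOME is not set" >&2
    return 1
  fi
  # The test target only needs the headers directory from this root.
  if [ ! -d "$root/include" ]; then
    echo "no include directory under $root" >&2
    return 1
  fi
  return 0
}

# Assumed usage before configuring the tests:
#   check_cublasmp_home "${CUBLASMP_HOME:-}" && cmake -GNinja -S. -Bbuild
```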
This is fine as a quick hack, but if we add more L1 tests we should have separate build processes for L0 and L1 tests. This way the L0 tests don't need to build the L1 tests or require MPI. It also removes the need for the hacky regex exclusions in the QA script.
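
To illustrate the concern in this comment: when L0 and L1 tests live in one build, the L0 run has to filter out the distributed tests by name. The sketch below mimics that filtering with `grep` on a list of test names. The names and the pattern are assumptions made for illustration and are not taken from the QA script.

```shell
# Hypothetical test names; only the CommGemmTest ones would need MPI.
all_tests='OperatorTest.Cast
OperatorTest.Transpose
CommGemmTest.AgGemm
CommGemmTest.GemmRs
CommGemmTest.GemmAr'

# Assumed exclusion pattern, similar in spirit to a name-based test filter.
l1_pattern='^CommGemmTest\.'

# L0 run: everything that does NOT match the pattern.
l0_tests=$(printf '%s\n' "$all_tests" | grep -Ev "$l1_pattern")
# L1 run: only the distributed tests.
l1_tests=$(printf '%s\n' "$all_tests" | grep -E "$l1_pattern")

echo "L0:"; echo "$l0_tests"
echo "L1:"; echo "$l1_tests"
```

Every new L1 suite would have to be added to the pattern by hand. With separate build trees, as the comment proposes, the L0 binary would never contain those tests, so no pattern would be needed.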