From 8a3c63e177c547308d4ab5ac5fc91447ecaa30eb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Mar 2025 15:58:18 +0100 Subject: [PATCH 1/3] Raw files without cutlass. --- .gitmodules | 3 + Cargo.lock | 74 +- Cargo.toml | 11 +- backends/candle/Cargo.toml | 15 +- candle-extensions/candle-cublaslt/Cargo.toml | 3 +- .../candle-flash-attn-v1/.gitignore | 3 + .../candle-flash-attn-v1/.gitmodules | 3 + .../candle-flash-attn-v1/Cargo.toml | 23 + .../candle-flash-attn-v1/LICENSE-APACHE | 201 ++ .../candle-flash-attn-v1/LICENSE-MIT | 23 + .../candle-flash-attn-v1/README.md | 4 + .../candle-flash-attn-v1/build.rs | 258 +++ .../candle-flash-attn-v1/cutlass | 1 + .../candle-flash-attn-v1/kernels/flash_api.cu | 114 ++ .../candle-flash-attn-v1/kernels/fmha.h | 153 ++ .../candle-flash-attn-v1/kernels/fmha/gemm.h | 451 +++++ .../kernels/fmha/gmem_tile.h | 554 ++++++ .../kernels/fmha/kernel_traits.h | 116 ++ .../candle-flash-attn-v1/kernels/fmha/mask.h | 90 + .../kernels/fmha/smem_tile.h | 1703 +++++++++++++++++ .../kernels/fmha/softmax.h | 607 ++++++ .../candle-flash-attn-v1/kernels/fmha/utils.h | 1215 ++++++++++++ .../candle-flash-attn-v1/kernels/fmha_api.cpp | 275 +++ .../kernels/fmha_fprop_kernel_1xN.h | 706 +++++++ .../kernels/fmha_fwd_hdim128.cu | 12 + .../kernels/fmha_fwd_hdim32.cu | 17 + .../kernels/fmha_fwd_hdim64.cu | 17 + .../kernels/fmha_fwd_launch_template.h | 91 + .../kernels/fmha_kernel.h | 78 + .../candle-flash-attn-v1/kernels/fmha_utils.h | 99 + .../candle-flash-attn-v1/kernels/philox.cuh | 157 ++ .../kernels/static_switch.h | 40 + .../candle-flash-attn-v1/src/ffi.rs | 41 + .../candle-flash-attn-v1/src/lib.rs | 272 +++ .../tests/flash_attn_tests.rs | 61 + .../candle-layer-norm/.gitignore | 3 + .../candle-layer-norm/Cargo.toml | 15 + .../candle-layer-norm/LICENSE-APACHE | 201 ++ .../candle-layer-norm/LICENSE-MIT | 23 + candle-extensions/candle-layer-norm/README.md | 14 + candle-extensions/candle-layer-norm/build.rs | 256 +++ 
.../candle-layer-norm/kernels/ln.h | 204 ++ .../candle-layer-norm/kernels/ln_api.cu | 262 +++ .../kernels/ln_fwd_kernels.cuh | 273 +++ .../kernels/ln_kernel_traits.h | 172 ++ .../candle-layer-norm/kernels/ln_utils.cuh | 728 +++++++ .../candle-layer-norm/kernels/static_switch.h | 25 + .../candle-layer-norm/src/ffi.rs | 29 + .../candle-layer-norm/src/lib.rs | 509 +++++ candle-extensions/candle-rotary/.gitignore | 3 + candle-extensions/candle-rotary/Cargo.toml | 22 + .../candle-rotary/LICENSE-APACHE | 201 ++ candle-extensions/candle-rotary/LICENSE-MIT | 23 + candle-extensions/candle-rotary/README.md | 4 + candle-extensions/candle-rotary/build.rs | 56 + .../candle-rotary/kernels/cuda_compat.h | 27 + .../candle-rotary/kernels/rotary.cu | 131 ++ candle-extensions/candle-rotary/src/ffi.rs | 22 + candle-extensions/candle-rotary/src/lib.rs | 175 ++ .../candle-rotary/tests/rotary_tests.rs | 85 + 60 files changed, 10923 insertions(+), 31 deletions(-) create mode 100644 .gitmodules create mode 100644 candle-extensions/candle-flash-attn-v1/.gitignore create mode 100644 candle-extensions/candle-flash-attn-v1/.gitmodules create mode 100644 candle-extensions/candle-flash-attn-v1/Cargo.toml create mode 100644 candle-extensions/candle-flash-attn-v1/LICENSE-APACHE create mode 100644 candle-extensions/candle-flash-attn-v1/LICENSE-MIT create mode 100644 candle-extensions/candle-flash-attn-v1/README.md create mode 100644 candle-extensions/candle-flash-attn-v1/build.rs create mode 160000 candle-extensions/candle-flash-attn-v1/cutlass create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/flash_api.cu create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha/gemm.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha/gmem_tile.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha/kernel_traits.h create mode 100644 
candle-extensions/candle-flash-attn-v1/kernels/fmha/mask.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha/smem_tile.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha/softmax.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha/utils.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_api.cpp create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim128.cu create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim32.cu create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim64.cu create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_launch_template.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_kernel.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/fmha_utils.h create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/philox.cuh create mode 100644 candle-extensions/candle-flash-attn-v1/kernels/static_switch.h create mode 100644 candle-extensions/candle-flash-attn-v1/src/ffi.rs create mode 100644 candle-extensions/candle-flash-attn-v1/src/lib.rs create mode 100644 candle-extensions/candle-flash-attn-v1/tests/flash_attn_tests.rs create mode 100644 candle-extensions/candle-layer-norm/.gitignore create mode 100644 candle-extensions/candle-layer-norm/Cargo.toml create mode 100644 candle-extensions/candle-layer-norm/LICENSE-APACHE create mode 100644 candle-extensions/candle-layer-norm/LICENSE-MIT create mode 100644 candle-extensions/candle-layer-norm/README.md create mode 100644 candle-extensions/candle-layer-norm/build.rs create mode 100644 candle-extensions/candle-layer-norm/kernels/ln.h create mode 100644 candle-extensions/candle-layer-norm/kernels/ln_api.cu create mode 100644 candle-extensions/candle-layer-norm/kernels/ln_fwd_kernels.cuh create mode 100644 
candle-extensions/candle-layer-norm/kernels/ln_kernel_traits.h create mode 100644 candle-extensions/candle-layer-norm/kernels/ln_utils.cuh create mode 100644 candle-extensions/candle-layer-norm/kernels/static_switch.h create mode 100644 candle-extensions/candle-layer-norm/src/ffi.rs create mode 100644 candle-extensions/candle-layer-norm/src/lib.rs create mode 100644 candle-extensions/candle-rotary/.gitignore create mode 100644 candle-extensions/candle-rotary/Cargo.toml create mode 100644 candle-extensions/candle-rotary/LICENSE-APACHE create mode 100644 candle-extensions/candle-rotary/LICENSE-MIT create mode 100644 candle-extensions/candle-rotary/README.md create mode 100644 candle-extensions/candle-rotary/build.rs create mode 100644 candle-extensions/candle-rotary/kernels/cuda_compat.h create mode 100644 candle-extensions/candle-rotary/kernels/rotary.cu create mode 100644 candle-extensions/candle-rotary/src/ffi.rs create mode 100644 candle-extensions/candle-rotary/src/lib.rs create mode 100644 candle-extensions/candle-rotary/tests/rotary_tests.rs diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..247e1547 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "candle-extensions/candle-flash-attn-v1/cutlass"] + path = candle-extensions/candle-flash-attn-v1/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/Cargo.lock b/Cargo.lock index a099e6bf..2027fa2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -494,6 +494,29 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +[[package]] +name = "candle-core" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db8659ea87ee8197d2fc627348916cce0561330ee7ae3874e771691d3cecb2f" +dependencies = [ + "byteorder", + "candle-kernels 0.3.3", + "cudarc", + "gemm", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand 0.8.5", + 
"rand_distr", + "rayon", + "safetensors", + "thiserror 1.0.69", + "yoke", + "zip 0.6.6", +] + [[package]] name = "candle-core" version = "0.5.0" @@ -501,7 +524,7 @@ source = "git+https://github.com/OlivierDehaene/candle?rev=7e02ad856104799b73a94 dependencies = [ "accelerate-src", "byteorder", - "candle-kernels", + "candle-kernels 0.5.0", "candle-metal-kernels", "cudarc", "gemm", @@ -525,7 +548,7 @@ dependencies = [ name = "candle-cublaslt" version = "0.2.2" dependencies = [ - "candle-core", + "candle-core 0.5.0", "cudarc", "half", ] @@ -537,22 +560,31 @@ source = "git+https://github.com/OlivierDehaene/candle?rev=7e02ad856104799b73a94 dependencies = [ "anyhow", "bindgen_cuda", - "candle-core", + "candle-core 0.5.0", "half", ] [[package]] name = "candle-flash-attn-v1" version = "0.0.1" -source = "git+https://github.com/huggingface/candle-flash-attn-v1?rev=3f1870b0d708579904c76e41745c659c3f9fa038#3f1870b0d708579904c76e41745c659c3f9fa038" dependencies = [ "anyhow", - "candle-core", + "candle-core 0.5.0", + "candle-nn 0.3.3", "half", "num_cpus", "rayon", ] +[[package]] +name = "candle-kernels" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d80cdd5f1cc60d30ba61353cdba5accd0fbc4d4ef2fe707fcb5179a9821adbea" +dependencies = [ + "bindgen_cuda", +] + [[package]] name = "candle-kernels" version = "0.5.0" @@ -564,10 +596,9 @@ dependencies = [ [[package]] name = "candle-layer-norm" version = "0.0.1" -source = "git+https://github.com/huggingface/candle-layer-norm?rev=94c2add7d94c2d63aebde77f7534614e04dbaea1#94c2add7d94c2d63aebde77f7534614e04dbaea1" dependencies = [ "anyhow", - "candle-core", + "candle-core 0.5.0", "half", "num_cpus", "rayon", @@ -584,13 +615,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "candle-nn" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ddce8312032760a6791d6adc9c56dc54fd7c1be38d85dcc4862f1c75228bbc7" +dependencies = [ + "candle-core 0.3.3", + 
"half", + "num-traits", + "rayon", + "safetensors", + "serde", + "thiserror 1.0.69", +] + [[package]] name = "candle-nn" version = "0.5.0" source = "git+https://github.com/OlivierDehaene/candle?rev=7e02ad856104799b73a946ac1e153f0de77feaaf#7e02ad856104799b73a946ac1e153f0de77feaaf" dependencies = [ "accelerate-src", - "candle-core", + "candle-core 0.5.0", "candle-metal-kernels", "half", "intel-mkl-src", @@ -605,11 +651,11 @@ dependencies = [ [[package]] name = "candle-rotary" version = "0.0.1" -source = "git+https://github.com/huggingface/candle-rotary?rev=0a718a0856569a92f3112e64f10d07e4447822e8#0a718a0856569a92f3112e64f10d07e4447822e8" dependencies = [ "anyhow", "bindgen_cuda", - "candle-core", + "candle-core 0.5.0", + "candle-nn 0.3.3", "half", ] @@ -619,8 +665,8 @@ version = "0.5.0" source = "git+https://github.com/OlivierDehaene/candle?rev=7e02ad856104799b73a946ac1e153f0de77feaaf#7e02ad856104799b73a946ac1e153f0de77feaaf" dependencies = [ "byteorder", - "candle-core", - "candle-nn", + "candle-core 0.5.0", + "candle-nn 0.5.0", "fancy-regex", "num-traits", "rand 0.8.5", @@ -4202,12 +4248,12 @@ version = "1.6.1" dependencies = [ "accelerate-src", "anyhow", - "candle-core", + "candle-core 0.5.0", "candle-cublaslt", "candle-flash-attn", "candle-flash-attn-v1", "candle-layer-norm", - "candle-nn", + "candle-nn 0.5.0", "candle-rotary", "candle-transformers", "hf-hub", diff --git a/Cargo.toml b/Cargo.toml index 9f758d86..b9c26a8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,9 @@ members = [ "backends/python", "backends/grpc-client", "candle-extensions/candle-cublaslt", + "candle-extensions/candle-flash-attn-v1", + "candle-extensions/candle-layer-norm", + "candle-extensions/candle-rotary", "core", "router", ] @@ -49,13 +52,7 @@ candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104 candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" } candle-transformers = { 
git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" } candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-flash-attn" } - - -[patch.crates-io] -candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-core" } -candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" } -candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" } -candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-flash-attn" } +half = { version = "2.3.1", features = ["num-traits"] } [profile.release] debug = 0 diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index bb70f563..8b6b5c6a 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -9,15 +9,14 @@ homepage.workspace = true anyhow = { workspace = true } accelerate-src = { version = "0.3.2", optional = true } intel-mkl-src = { version = "0.8.1", optional = true } -candle = { version = "*", package = "candle-core", default-features = false } -candle-nn = { version = "*" } -candle-transformers = { version = "*" } -candle-flash-attn = { version = "*", optional = true } -candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "3f1870b0d708579904c76e41745c659c3f9fa038", optional = true } -# candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "cf789b7dd6d4abb19b03b9556442f94f0588b4a0", optional = true } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, 
optional = true} +candle-flash-attn-v1 = { path = "../../candle-extensions/candle-flash-attn-v1", optional = true } candle-cublaslt = { path = "../../candle-extensions/candle-cublaslt", optional = true } -candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "94c2add7d94c2d63aebde77f7534614e04dbaea1", optional = true } -candle-rotary = { git = "https://github.com/huggingface/candle-rotary", rev = "0a718a0856569a92f3112e64f10d07e4447822e8", optional = true } +candle-layer-norm = { path = "../../candle-extensions/candle-layer-norm", optional = true } +candle-rotary = { path = "../../candle-extensions/candle-rotary", optional = true } nohash-hasher = { workspace = true } text-embeddings-backend-core = { path = "../core" } tracing = { workspace = true } diff --git a/candle-extensions/candle-cublaslt/Cargo.toml b/candle-extensions/candle-cublaslt/Cargo.toml index ab0b240e..85be5b82 100644 --- a/candle-extensions/candle-cublaslt/Cargo.toml +++ b/candle-extensions/candle-cublaslt/Cargo.toml @@ -6,10 +6,9 @@ edition = "2021" description = "CUBLASLt gemm for the candle ML framework." 
[dependencies] -# candle = { version = "0.8", package = "candle-core", features = ["cuda"]} candle = { workspace=true, features = ["cuda"]} cudarc = { workspace = true, features = [ "cublaslt", "f16" ]} -half = { version = "2.3.1", features = ["num-traits"] } +half = { workspace = true} [features] static-linking = ["cudarc/static-linking"] diff --git a/candle-extensions/candle-flash-attn-v1/.gitignore b/candle-extensions/candle-flash-attn-v1/.gitignore new file mode 100644 index 00000000..fbc9a58c --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/.gitignore @@ -0,0 +1,3 @@ +.idea +target +Cargo.lock diff --git a/candle-extensions/candle-flash-attn-v1/.gitmodules b/candle-extensions/candle-flash-attn-v1/.gitmodules new file mode 100644 index 00000000..be2c482f --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cutlass"] + path = cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/candle-extensions/candle-flash-attn-v1/Cargo.toml b/candle-extensions/candle-flash-attn-v1/Cargo.toml new file mode 100644 index 00000000..5e445b88 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "candle-flash-attn-v1" +version = "0.0.1" +edition = "2021" + +description = "Flash attention V1 layer for the candle ML framework." 
+keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { workspace = true } +half = { workspace = true } + +[build-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +num_cpus = "1.15.0" +rayon = "1.7.0" + +[dev-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +candle-nn = { version = "0.3.0", features = ["cuda"] } diff --git a/candle-extensions/candle-flash-attn-v1/LICENSE-APACHE b/candle-extensions/candle-flash-attn-v1/LICENSE-APACHE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/candle-extensions/candle-flash-attn-v1/LICENSE-MIT b/candle-extensions/candle-flash-attn-v1/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/candle-extensions/candle-flash-attn-v1/README.md b/candle-extensions/candle-flash-attn-v1/README.md new file mode 100644 index 00000000..85fabb7f --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/README.md @@ -0,0 +1,4 @@ +# Candle Flash Attention v1 Layer + +Flash Attention v2 does not support Turing GPUs (T4, RTX 2080). This layer can be used in replacement of the official +flash attention Candle layer in the meantime. diff --git a/candle-extensions/candle-flash-attn-v1/build.rs b/candle-extensions/candle-flash-attn-v1/build.rs new file mode 100644 index 00000000..2722045a --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/build.rs @@ -0,0 +1,258 @@ +// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel. +// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment +// variable in order to cache the compiled artifacts and avoid recompiling too often. +use anyhow::{Context, Result}; +use rayon::prelude::*; +use std::path::PathBuf; +use std::str::FromStr; + +const KERNEL_FILES: [&str; 4] = [ + "flash_api.cu", + "fmha_fwd_hdim32.cu", + "fmha_fwd_hdim64.cu", + "fmha_fwd_hdim128.cu", +]; + +fn main() -> Result<()> { + let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( + |_| num_cpus::get_physical(), + |s| usize::from_str(&s).unwrap(), + ); + + rayon::ThreadPoolBuilder::new() + .num_threads(num_cpus) + .build_global() + .unwrap(); + + println!("cargo:rerun-if-changed=build.rs"); + for kernel_file in KERNEL_FILES.iter() { + println!("cargo:rerun-if-changed=kernels/{kernel_file}"); + } + println!("cargo:rerun-if-changed=kernels/**.h"); + println!("cargo:rerun-if-changed=kernels/**.cuh"); + println!("cargo:rerun-if-changed=kernels/fmha/**.h"); + let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); + let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { + Err(_) => + { + #[allow(clippy::redundant_clone)] + 
out_dir.clone() + } + Ok(build_dir) => { + let path = PathBuf::from(build_dir); + path.canonicalize().expect(&format!( + "Directory doesn't exists: {} (the current directory is {})", + &path.display(), + std::env::current_dir()?.display() + )) + } + }; + set_cuda_include_dir()?; + + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); + + let compute_cap = compute_cap()?; + + let out_file = build_dir.join("libflashattentionv1.a"); + + let kernel_dir = PathBuf::from("kernels"); + let cu_files: Vec<_> = KERNEL_FILES + .iter() + .map(|f| { + let mut obj_file = out_dir.join(f); + obj_file.set_extension("o"); + (kernel_dir.join(f), obj_file) + }) + .collect(); + let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified()); + let should_compile = if out_file.exists() { + kernel_dir + .read_dir() + .expect("kernels folder should exist") + .any(|entry| { + if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) { + let in_modified = entry.metadata().unwrap().modified().unwrap(); + in_modified.duration_since(*out_modified).is_ok() + } else { + true + } + }) + } else { + true + }; + if should_compile { + cu_files + .par_iter() + .map(|(cu_file, obj_file)| { + let mut command = std::process::Command::new("nvcc"); + command + .arg("-O3") + .arg("-std=c++17") + .arg(format!("--gpu-architecture=sm_{compute_cap}")) + .arg("-c") + .args(["-o", obj_file.to_str().unwrap()]) + .args(["--default-stream", "per-thread"]) + .arg("-Icutlass/include") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--ptxas-options=-v") + .arg("--verbose"); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } + command.arg(cu_file); + let 
output = command + .spawn() + .context("failed spawning nvcc")? + .wait_with_output()?; + if !output.status.success() { + anyhow::bail!( + "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + &command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + Ok(()) + }) + .collect::>()?; + let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::>(); + let mut command = std::process::Command::new("nvcc"); + command + .arg("--lib") + .args(["-o", out_file.to_str().unwrap()]) + .args(obj_files); + let output = command + .spawn() + .context("failed spawning nvcc")? + .wait_with_output()?; + if !output.status.success() { + anyhow::bail!( + "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + &command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + } + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=flashattentionv1"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=stdc++"); + + Ok(()) +} + +fn set_cuda_include_dir() -> Result<()> { + // NOTE: copied from cudarc build.rs. 
+ let env_vars = [ + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDNN_LIB", + ]; + let env_vars = env_vars + .into_iter() + .map(std::env::var) + .filter_map(Result::ok) + .map(Into::::into); + + let roots = [ + "/usr", + "/usr/local/cuda", + "/opt/cuda", + "/usr/lib/cuda", + "C:/Program Files/NVIDIA GPU Computing Toolkit", + "C:/CUDA", + ]; + let roots = roots.into_iter().map(Into::::into); + let root = env_vars + .chain(roots) + .find(|path| path.join("include").join("cuda.h").is_file()) + .context("cannot find include/cuda.h")?; + println!( + "cargo:rustc-env=CUDA_INCLUDE_DIR={}", + root.join("include").display() + ); + Ok(()) +} + +#[allow(unused)] +fn compute_cap() -> Result { + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + + // Try to parse compute caps from env + let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); + compute_cap_str + .parse::() + .context("Could not parse code")? + } else { + // Use nvidia-smi to get the current compute cap + let out = std::process::Command::new("nvidia-smi") + .arg("--query-gpu=compute_cap") + .arg("--format=csv") + .output() + .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; + let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; + let mut lines = out.lines(); + assert_eq!( + lines.next().context("missing line in stdout")?, + "compute_cap" + ); + let cap = lines + .next() + .context("missing line in stdout")? + .replace('.', ""); + let cap = cap + .parse::() + .with_context(|| format!("cannot parse as int {cap}"))?; + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); + cap + }; + + // Grab available GPU codes from nvcc and select the highest one + let (supported_nvcc_codes, max_nvcc_code) = { + let out = std::process::Command::new("nvcc") + .arg("--list-gpu-code") + .output() + .expect("`nvcc` failed. 
Ensure that you have CUDA installed and that `nvcc` is in your PATH."); + let out = std::str::from_utf8(&out.stdout).unwrap(); + + let out = out.lines().collect::>(); + let mut codes = Vec::with_capacity(out.len()); + for code in out { + let code = code.split('_').collect::>(); + if !code.is_empty() && code.contains(&"sm") { + if let Ok(num) = code[1].parse::() { + codes.push(num); + } + } + } + codes.sort(); + let max_nvcc_code = *codes.last().unwrap(); + (codes, max_nvcc_code) + }; + + // Check that nvcc supports the asked compute cap + if !supported_nvcc_codes.contains(&compute_cap) { + anyhow::bail!( + "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." + ); + } + if compute_cap > max_nvcc_code { + anyhow::bail!( + "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" + ); + } + + Ok(compute_cap) +} diff --git a/candle-extensions/candle-flash-attn-v1/cutlass b/candle-extensions/candle-flash-attn-v1/cutlass new file mode 160000 index 00000000..5f13dcad --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/cutlass @@ -0,0 +1 @@ +Subproject commit 5f13dcad781284678edafa3b8d108120cfc6a6e4 diff --git a/candle-extensions/candle-flash-attn-v1/kernels/flash_api.cu b/candle-extensions/candle-flash-attn-v1/kernels/flash_api.cu new file mode 100644 index 00000000..06133921 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/flash_api.cu @@ -0,0 +1,114 @@ +#include "fmha.h" +#include "fmha_utils.h" + +void run_fmha_fwd(Launch_params &launch_params) { + if (launch_params.params.d <= 32) { + run_fmha_fwd_hdim32(launch_params); + } else if (launch_params.params.d <= 64) { + run_fmha_fwd_hdim64(launch_params); + } else if (launch_params.params.d <= 128) { + run_fmha_fwd_hdim128(launch_params); + } +} + +extern "C" void run_mha( + void *q_ptr, + void *k_ptr, + void *v_ptr, + void *o_ptr, + void *o_tmp_ptr, + void *softmax_lse_ptr, + + int32_t *cu_seqlens_q_ptr, + int32_t 
*cu_seqlens_k_ptr, + + uint32_t q_row_stride, + uint32_t k_row_stride, + uint32_t v_row_stride, + uint32_t o_row_stride, + uint32_t o_tmp_row_stride, + + uint32_t q_head_stride, + uint32_t k_head_stride, + uint32_t v_head_stride, + uint32_t o_head_stride, + uint32_t o_tmp_head_stride, + + uint32_t b, + uint32_t h, + uint32_t d, + float softmax_scale, + + uint32_t seqlen_q, + uint32_t seqlen_k, + + int is_causal, + int is_bf16, + + int32_t multi_processor_count, + int32_t num_splits +) { + Data_type data_type = !is_bf16 ? DATA_TYPE_FP16 : DATA_TYPE_BF16; + + Launch_params launch_params; + + launch_params.elts_per_thread = 0; + launch_params.multi_processor_count = multi_processor_count; + launch_params.stream = 0; + launch_params.is_dropout = false; + launch_params.return_softmax = false; + + FMHA_fprop_params ¶ms = launch_params.params; + + // Set the pointers and strides. + params.q_ptr = q_ptr; + params.k_ptr = k_ptr; + params.v_ptr = v_ptr; + params.o_ptr = o_ptr; + params.o_tmp_ptr = o_tmp_ptr; + + params.softmax_lse_ptr = softmax_lse_ptr; + + // All stride are in elements, not bytes. + + params.q_row_stride_in_elts = q_row_stride; + params.k_row_stride_in_elts = k_row_stride; + params.v_row_stride_in_elts = v_row_stride; + params.o_row_stride_in_elts = o_row_stride; + params.o_tmp_row_stride_in_elts = o_tmp_row_stride; + + params.q_head_stride_in_elts = q_head_stride; + params.k_head_stride_in_elts = k_head_stride; + params.v_head_stride_in_elts = v_head_stride; + params.o_head_stride_in_elts = o_head_stride; + params.o_tmp_head_stride_in_elts = o_tmp_head_stride; + + // Set the dimensions. + params.h = h; + params.b = b; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.d = d; + + // Set the different scale values. 
+ const float scale_bmm1 = softmax_scale; + params.scale_bmm1f = scale_bmm1; + set_alpha(params.scale_bmm1, scale_bmm1, data_type); + + params.p_dropout = 1.; // probability to keep + params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f; + set_alpha(params.scale_dropout, params.rp_dropout, data_type); + + params.is_bf16 = is_bf16; + params.is_causal = is_causal; + + params.cu_seqlens_q = cu_seqlens_q_ptr; + params.cu_seqlens_k = cu_seqlens_k_ptr; + + params.num_splits = num_splits; + + run_fmha_fwd(launch_params); +} diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha.h new file mode 100644 index 00000000..cda60d81 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha.h @@ -0,0 +1,153 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "fmha_utils.h" + + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + // size_t qkv_stride_in_elts; + // size_t qkv_stride_in_bytes; + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + uint32_t q_row_stride_in_elts; + uint32_t k_row_stride_in_elts; + uint32_t v_row_stride_in_elts; + uint32_t q_head_stride_in_elts; + uint32_t k_head_stride_in_elts; + uint32_t v_head_stride_in_elts; + + // The number of heads. + int h; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct FMHA_fprop_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + + // The stride between rows of O. 
+ // size_t o_stride_in_elts; + // size_t o_stride_in_bytes; + uint32_t o_row_stride_in_elts; + uint32_t o_head_stride_in_elts; + uint32_t o_tmp_row_stride_in_elts; + uint32_t o_tmp_head_stride_in_elts; + + // The pointer to the O_tmp matrix, which holds O intermediate value during + // the loop; + void *__restrict__ o_tmp_ptr; + + // The pointer to the S matrix. + void * __restrict__ s_ptr; + // The stride between rows of the S matrix. + // int64_t s_stride_in_bytes; + uint32_t s_stride_in_bytes; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, d; + + // The scaling factors for the kernel. + float scale_bmm1f; + uint32_t scale_bmm1; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + int *__restrict__ blockmask; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint32_t p_dropout_in_uint; + uint16_t p_dropout_in_uint16_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_bmm1_rp_dropout; + + // Scale factor of 1 / (1 - p_dropout), in half2. + uint32_t scale_dropout; + + // Random state. +// at::PhiloxCudaState philox_args; + // Pointer to the RNG seed (idx 0) and offset (idx 1). +// uint64_t * rng_state; + + bool is_bf16; + bool is_causal; + + int num_splits; // How many SMs per attention matrix. 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Launch_params{ + size_t elts_per_thread; + + int multi_processor_count; + + cudaStream_t stream; + + bool is_dropout; + bool return_softmax; + + Kernel_params params; + int num_full_heads; + int num_main_groups; + int heads_last_wave; + int main_steps; + int rest_steps; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_fmha_fwd_hdim32(Launch_params &launch_params); +void run_fmha_fwd_hdim64(Launch_params &launch_params); +void run_fmha_fwd_hdim128(Launch_params &launch_params); diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/gemm.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/gemm.h new file mode 100644 index 00000000..4b759871 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/gemm.h @@ -0,0 +1,451 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "utils.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/layout/layout.h" +#include "cutlass/arch/mma.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ > +struct Fragment_base_ { + + // The data type. + using Data_type = Data_type_; + // default input type + using Input_type_ = Data_type_; + // Does it store the array of elements. + static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8; + // The number of elements. + static constexpr int NUM_ELTS = NUM_ELTS_; + // The size of element in bits. + static constexpr int BITS_PER_ELT = BITS_PER_ELT_; + // The size of byte of a single register. + static constexpr int BYTES_PER_REG = 4; + // The size in bits. + static constexpr int BITS_PER_REG = BYTES_PER_REG * 8; + // The number of registers needed to store the fragment. 
+ static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG); + // The size in bytes (as returned by sizeof(Fragment_base<>). + static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG; + // The alignment. + static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the elements. + typename Data_type_, + // The number of elements. + int NUM_ELTS_, + // The alignment if you want to force a value -- use 0 otherwise. + int ALIGNMENT_ = 0, + // The base class. + typename Base_ = Fragment_base_ +> +struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { + + // The size of a load/store. + static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t); + + // Clear the fragment. Using PTX in that code seems to produce better SASS... + inline __device__ void clear() { + #pragma unroll + for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { + asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : ); + } + } + + // Immutable access to a register. + inline __device__ const uint32_t& reg(int ii) const { + return this->regs_[ii]; + } + + // Mutable access to a register. + inline __device__ uint32_t& reg(int ii) { + return this->regs_[ii]; + } + + uint32_t regs_[Base_::NUM_REGS]; + + // Immutable access to the elements. + inline __device__ const Data_type_& elt(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. + inline __device__ Data_type_& elt(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to the elements with a cast. + template< typename Cast_type > + inline __device__ const Cast_type& elt_as(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. 
+ template< typename Cast_type > + inline __device__ Cast_type& elt_as(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Add another fragment. + inline __device__ void add(const Fragment &other) { + // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS? + // Also are we doing int addition or __half2 addition? + #pragma unroll + for( int ii = 0; ii < NUM_ELTS_; ++ii ) { + this->elt(ii) += other.elt(ii); + } + } + + // Multiply by another fragment. + inline __device__ void hmul(const Fragment &other) { + #pragma unroll + for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { + this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii)); + } + } + + template + inline __device__ void hrelu_() { + #pragma unroll + for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { + this->reg(ii) = fmha::hrelu2(this->reg(ii)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Layout > +struct Fragment_a : public Fragment { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Layout > +struct Fragment_b : public Fragment { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Fragment_accumulator : public Fragment { + + // The base class. + using Base = Fragment; + + // Add two fragments. + template< typename Other_fragment_ > + inline __device__ void add(const Other_fragment_ &other) { + for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul_(const float other) { + for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. 
+ template< typename Layout_a, typename Layout_b > + inline __device__ void mma(const Fragment_a &a, + const Fragment_b &b) { + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ + " {%0, %1, %2, %3}, \n" \ + " {%4, %5, %6, %7}, \n" \ + " {%8, %9}, \n" \ + " {%0, %1, %2, %3}; \n" \ + : "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) + , "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ + " {%0, %1, %2, %3}, \n" \ + " {%4, %5, %6, %7}, \n" \ + " {%8, %9}, \n" \ + " {%0, %1, %2, %3}; \n" \ + : "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) + , "r"(b.reg(2)), "r"(b.reg(3))); + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Fragment, int M, int N > +inline __device__ void clear(Fragment (&frag)[M][N]) { + #pragma unroll + for( int mi = 0; mi < M; ++mi ) { + #pragma unroll + for( int ni = 0; ni < N; ++ni ) { + frag[mi][ni].clear(); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Accumulator_type, int WARPS_K > +struct Clear_accumulator { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int WARPS_K > +struct Clear_accumulator { + template< typename Acc, int M, int N > + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { + + #pragma unroll + for( int mi = 0; mi < M; ++mi ) { + #pragma unroll + for( int ni = 0; ni < N; ++ni ) { + acc[mi][ni].mma(a[mi], b[ni]); + 
} + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps half types => cutlass data types +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct HalfTypeToCutlassType { using Type = Type_; }; + +/// Statically maps __half => cutlass::half_t +template <> struct HalfTypeToCutlassType<__half> { + using Type = cutlass::half_t; +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +template <> struct HalfTypeToCutlassType<__nv_bfloat16> { + using Type = cutlass::bfloat16_t; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { + using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; +#else + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + // TD [2022-06-02] We don't support Volta (SM70) yet. 
+ assert(0); +#endif + using Element = typename HalfTypeToCutlassType::Type; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type; + + constexpr int kIters = Shape::kK / InstructionShape::kK; + // using FragmentA = typename WarpMma::FragmentA; + // using FragmentB = typename WarpMma::FragmentB; + using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA; + using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB; + using FragmentC = typename WarpMma::FragmentC; + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) { + // printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements); + // printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements); + // printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements); + // printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements); + // printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements); + // printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements); + // } + + // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS); + // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS); + static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS); + static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS); + static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS); + // const FragmentA a_cl = reinterpret_cast(a); + // const FragmentB b_cl = reinterpret_cast(b); + FragmentC c_cl = reinterpret_cast(acc); + FragmentA a_cl[kIters][M]; + FragmentA 
b_cl[kIters][N]; + constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2; + #pragma unroll + for (int iter = 0; iter < kIters; iter++) { + #pragma unroll + for (int mi = 0; mi < M; mi++) { + uint32_t *a_ptr = a_cl[iter][mi].raw_data(); + #pragma unroll + for (int ki = 0; ki < kRegs; ki++) { + a_ptr[ki] = a[mi].regs_[iter * kRegs + ki]; + } + } + } + #pragma unroll + for (int iter = 0; iter < kIters; iter++) { + #pragma unroll + for (int ni = 0; ni < N; ni++) { + uint32_t *b_ptr = b_cl[iter][ni].raw_data(); + #pragma unroll + for (int ki = 0; ki < kRegs; ki++) { + // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki]; + // TD [2022-06-02] For some reason the order for frag_b is different. + b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter]; + } + } + } + + WarpMma mma_op; + // mma_op(c_cl, a_cl, b_cl, c_cl); + #pragma unroll + for (int iter = 0; iter < kIters; iter++) { + mma_op(c_cl, reinterpret_cast(a_cl[iter]), + reinterpret_cast(b_cl[iter]), c_cl); + } + + // The modified c_cl is not copied back into acc, idk why + #pragma unroll + for (int mi = 0; mi < M; mi++) { + #pragma unroll + for (int ni = 0; ni < N; ni++) { + #pragma unroll + for (int i =0; i < 8; i++) { + acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i]; + } + } + } + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The number of rows in the CTA tile. + int M_, + // The number of cols in the CTA tile. + int N_, + // The number of elements in the the K dimension of the GEMM loop. + int K_, + // The number of rows of warps. + int WARPS_M_, + // The number of cols of warps. + int WARPS_N_, + // The number of warps in the K dimension of the GEMM loop. + int WARPS_K_> +struct Cta_tile_ { + + static constexpr int M = M_, N = N_, K = K_; + // The number of warps. + static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_; + // The number of warps per CTA. 
+ static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K; + // The number of threads per warp. + static constexpr int THREADS_PER_WARP = 32; + // The number of threads per CTA. + static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_tile { + // The number of elements computed with a single warp-MMA. + static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16; + + // The number of elements computed with a single CTA-MMA. + static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K; + + // The number of MMAs needed to compute the GEMM. + static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA), + MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA), + MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA); + + // // The number of elements computed per warp. 
+ // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA, + // N_PER_WARP = MMAS_N * N_PER_MMA, + // K_PER_WARP = MMAS_K * K_PER_MMA; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using A_type = uint16_t; +using B_type = uint16_t; +using C_type = uint16_t; +using Accumulator_type = float; +using Epilogue_type = float; + +constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8; +constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8; +constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_extd = Cta_tile_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_with_k_with_padding = Cta_tile_extd::VALUE, + Cta_tile_::WARPS_M, + Cta_tile_::WARPS_N, + Cta_tile_::WARPS_K>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/gmem_tile.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/gmem_tile.h new file mode 100644 index 00000000..2ed187a4 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/gmem_tile.h @@ -0,0 +1,554 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "utils.h" + +namespace fmha { + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile_, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + int BYTES_PER_LDGS_ = 16 +> +struct Gmem_tile_qkv { + + using Cta_tile = Cta_tile_; + + static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8; + // The size of each LDG. + static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_; + // The size of a row in bytes. 
+ static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8; + + // The number of threads to load a "row" of the matrix. + static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG; + + static constexpr int ROWS = ROWS_; + // The number of "rows" loaded per LDG. + static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; + // The number of LDGs needed to load a chunk of the Q matrix. + static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG); + + // Ctor. + template< typename BInfo > + inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, const int headdim, + const BInfo &binfo, const int tidx, bool use_seqlen_q) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k) + , ptr(reinterpret_cast(ptr_)) + , tidx_(tidx) + , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) { + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable the loads. + // TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it + // row_ = row; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; + uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes); + // Add the block index. + // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; + row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); + + // Assemble the final pointer. + ptr += row_offset + col * BYTES_PER_LDG; + } + + // Store data to shared memory. 
+ template< typename Smem_tile > + inline __device__ void commit(Smem_tile &smem_tile) { + smem_tile.store(fetch_); + } + + inline __device__ void load() { + int row_ = tidx_ / THREADS_PER_ROW; + const void *ptrs[LDGS]; + uint32_t preds[LDGS]; + #pragma unroll + for( int ii = 0; ii < LDGS; ++ii ) { + // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); + fetch_[ii] = make_uint4(0, 0, 0, 0); + } + + // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) + Ldg_functor fct(fetch_, ptrs); + #pragma unroll + for( int ii = 0; ii < LDGS; ++ii ) { + fct.load(ii, preds[ii]); + } + } + + // Store data to memory. + inline __device__ void store(const uint4 (&data)[LDGS]) { + int row_ = tidx_ / THREADS_PER_ROW; + #pragma unroll + for( int ii = 0; ii < LDGS; ++ii ) { + // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) { + fmha::stg(ptr_, data[ii]); + } + } + } + + inline __device__ void move(const int steps = 1) { + // ptr += (int64_t)ROWS * row_stride_in_bytes * steps; + ptr += (uint32_t)ROWS * row_stride_in_bytes * steps; + actual_seqlen -= ROWS * steps; + } + + // The stride between rows for the QKV matrice. + // int64_t row_stride_in_bytes; + const uint32_t row_stride_in_bytes; + // The pointer. + char *ptr; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row the thread is processing as we move the tile. + // int row_; + const int tidx_; + // The length of the sequence loaded by that memory tile. 
+ int actual_seqlen; + const bool col_predicate; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + typename Cta_tile, + int BYTES_PER_ELEMENT = 2 +> +struct Gmem_tile_o { + + static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4); + + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The size of each element. + // static constexpr int BYTES_PER_ELEMENT = 2; + // The size of each STG. + static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4; + static constexpr int COLS = Cta_tile::N; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of threads to store a "row" of the matrix. + static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA; + // The number of outter loop for the stores. + static constexpr int LOOPS = ROWS / ROWS_PER_LOOP; + + // The number of "rows" stored per STG. + static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; + // Do we have to guard against partial writes/reads. + static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0; + // The number of STGs needed to store a chunk of the Q matrix. + static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG); + // The number of STGs needed to store a chunk of the Q matrix in total. + static constexpr int STGS = STGS_PER_LOOP * LOOPS; + + // Ctor. 
+    template<typename BInfo>
+    // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
+    inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
+                                  const uint32_t head_stride_in_elts, const int headdim,
+                                  const BInfo &binfo, const int tidx)
+        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
+        , actual_seqlen_q(binfo.actual_seqlen_q)
+        , ptr_(reinterpret_cast<char *>(ptr))
+        , tidx_(tidx)
+        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
+
+        // Compute the position in the sequence (within the CTA for the moment).
+        int row = tidx / THREADS_PER_ROW;
+        // Compute the position of the thread in the row.
+        int col = tidx % THREADS_PER_ROW;
+
+        // Store the row as we need it to disable loads.
+        // row_ = row;
+
+        // The row offset in the batched GEMM.
+        // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
+        uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
+        row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
+        // Assemble the final pointer.
+        ptr_ += row_offset + col * BYTES_PER_STG;
+
+        // Is that thread active on the last STG?
+        if( HAS_INCOMPLETE_STG ) {
+            is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
+        }
+    }
+
+    // Store data to global memory.
+ template + inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { + int row_ = tidx_ / THREADS_PER_ROW; + #pragma unroll + for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { + int jj = mi * STGS_PER_LOOP + ii; + if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) { + break; + } + + if (BYTES_PER_ELEMENT == 4) { + if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { + fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]); + } + } else if (BYTES_PER_ELEMENT == 2) { + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + uint2 out = fmha::float4_pack(x, y, z, w); + if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { + fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out); + } + } + } + } + + // Store data to global memory with atomicAdd. + inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) { + static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats + int row_ = tidx_ / THREADS_PER_ROW; + #pragma unroll + for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { + int jj = mi * STGS_PER_LOOP + ii; + if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) { + break; + } + + if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { + float *ptr_ = reinterpret_cast(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes); + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + atomicAdd(ptr_ + jj, reinterpret_cast(src[ii])[jj]); + } + } + } + } + + // Load data from global memory. 
+ inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + static_assert(BYTES_PER_ELEMENT == 4); + int row_ = tidx_ / THREADS_PER_ROW; + #pragma unroll + for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { + int jj = mi * STGS_PER_LOOP + ii; + if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) { + break; + } + + if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { + fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes); + } + } + } + + inline __device__ void move(const int steps = 1) { + // row_ += ROWS * steps; + // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps; + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + actual_seqlen_q -= ROWS * steps; + } + + // The stride between rows for the QKV matrice. + // int64_t row_stride_in_bytes; + const uint32_t row_stride_in_bytes; + // The pointer. + char *ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + // The length of the sequence loaded by that memory tile. + int actual_seqlen_q; + const int tidx_; + const bool col_predicate; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Cta_tile, int BYTES_PER_ELEMENT > +struct Gmem_tile_mma_sd { + + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // Each STG stores 8 elements. + static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8; + // The number of MMAs in the M dimension. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int MMAS_N = Mma_tile::MMAS_N; + // The number of rows computed per MMA per thread block. + static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA; + // The number of cols computed per MMA per thread block. + static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA; + // The number of threads per block. 
+ static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA; + // The size of each row in bytes. I.e. how many bytes are stored per STG. + static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG; + // The distance between elements stored per loop (in bytes). + static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW; + + // The type of elements stored per STG. + using Type = typename fmha::Uint_from_size_in_bytes::Type; + + // Ctor. + template + inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) + : ptr_(static_cast(ptr)) { + + // The block index. + // size_t bidx = bidb * params.h + bidh; + uint32_t bidx = bidb * params.h + bidh; + + // The distance between two blocks (in bytes). + // const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; + const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; + // Set store location for each thread at the beginning of the loop + ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG; + } + + // Store to global memory. + inline __device__ void store(const Type &data, const int mi, const int ni) { + // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::stg(ptr_ + offset, data); + } + + // Load from global memory. + inline __device__ void load(Type &data, const int mi, const int ni) { + // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::ldg(data, ptr_ + offset); + } + + // Move to the next tile. + inline __device__ void move(const int steps = 1) { + ptr_ += LOOP_STRIDE_BYTES * steps; + } + + // The pointer in global memory. 
+ char *ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > +struct Gmem_tile_mma_s : public Base { + + // The number of mmas in the vertical dimension. + static constexpr int M = Base::MMAS_M; + // The number of mmas in the horizontal dimension. + static constexpr int N = Base::MMAS_N; + // The type of the vectors stored by each STG. + using Type = typename Base::Type; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) + : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { + } + + // Store to global memory. + template + inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){ + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + uint4 dst; + dst.x = frag[ni][mi].reg(0); + dst.y = frag[ni][mi].reg(2); + dst.z = frag[ni][mi].reg(1); + dst.w = frag[ni][mi].reg(3); + if( mask.any_valid(mi, ni) ) { + Base::store(dst, mi, ni); + } + } + } + } + + // Load from global memory. + template + inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + regs[mi][ni] = make_uint4(0, 0, 0, 0); + if( mask.any_valid(mi, ni) ) { + Base::load(regs[mi][ni], mi, ni); + } + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile +> +struct Gmem_summary_stats { + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + + // The size of each element. 
+ static constexpr int BYTES_PER_ELEMENT = 4; + static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT; + static constexpr int ROWS = Cta_tile::M; + + // Ctor. + template + inline __device__ Gmem_summary_stats(void *ptr, const Params ¶ms, const int tidx) + : ptr_(reinterpret_cast(ptr)), tidx_(tidx) { + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The block index. + // size_t bidx = bidb * params.h + bidh; + uint32_t bidx = bidb * params.h + bidh; + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The distance between two blocks (in bytes). + // size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT; + uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT; + + // Set store location for each thread at the beginning of the loop + ptr_row_ = ptr_ + bidx * block_stride_bytes; + ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT; + } + + // Store data to global memory. + inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) { + int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + if ((warp == 0) && (lane % 4 == 0)) { + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]); + fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]); + } + } + } + + // Store data to global memory. + inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) { + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]); + } + } + + // Load from global memory. 
+ inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) { + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); + fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); + } + } + + // Load from global memory. + inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) { + char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT; + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); + fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); + } + } + + // Store data to global memory. + template + inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) { + #pragma unroll + for (int ni = 0; ni < N; ++ni) { + fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT); + } + } + + // Move the pointer to the next location. + inline __device__ void move() { + ptr_ += ROWS * BYTES_PER_ELEMENT; + ptr_row_ += ROWS * BYTES_PER_ELEMENT; + } + + // Move the pointer to the next location. + inline __device__ void move(const int steps) { + ptr_ += ROWS * BYTES_PER_ELEMENT * steps; + ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps; + } + + // The pointer. 
+ char *ptr_; + char *ptr_row_; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/kernel_traits.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/kernel_traits.h new file mode 100644 index 00000000..2b8aac85 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/kernel_traits.h @@ -0,0 +1,116 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#include + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FMHA_kernel_traits { + + // The CTA description for the 1st GEMM. + using Cta_tile_p = fmha::Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = fmha::Cta_tile_extd; + + // Do we use one buffer for K and V. + static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u; + // Do we keep K in registers. + static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u; + // Do we keep V in registers. + static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u; + + // The global memory tile to load Q. + using Gmem_tile_q = fmha::Gmem_tile_qkv; + + // The shared memory tile to swizzle Q. + // using Smem_tile_q = fmha::Smem_tile_a; + using Smem_tile_q = fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = fmha::Gmem_tile_qkv; + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = fmha::Gmem_tile_qkv; + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = fmha::Gmem_tile_o; + // The shared memory tile for O. 
+    using Smem_tile_o = fmha::Smem_tile_o;
+
+    // The global memory tile to load/store S.
+    using Gmem_tile_s = fmha::Gmem_tile_mma_s;
+
+    // The shared memory tile to transpose S.
+    using Smem_tile_st = fmha::Smem_tile_mma_transposed;
+
+    using Gmem_tile_do = fmha::Gmem_tile_qkv;
+
+    // // The global memory tile to store the accumulated dK and dV
+    // // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV
+    // // where there are 16 bits per elements and 16 bytes per load. In reality we won't
+    // // be issuing any load or store of size 32 bytes.
+    // using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv;
+
+    // The global memory tile to store the softmax sum.
+    using Gmem_softmax_sum = fmha::Gmem_summary_stats;
+
+    // The shared memory tile to store dp sum.
+    using Smem_dp_sum = fmha::Smem_tile_dp_sum;
+
+    using elem_type = elem_type_;
+
+    // Make sure the number of threads match.
+    static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
+
+    // The number of threads.
+    static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
+    // Make sure the number of threads matches both CTAs.
+    static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");
+
+    // The amount of shared memory needed to load Q and K.
+    static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE;
+    // The extra amount of shared memory needed to load V.
+    static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE;
+    // The amount of shared memory needed for Q, K and V..
+    static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V;
+    // The amount of shared memory needed to load Q and store O.
+    static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE;
+
+    // The amount of shared memory needed for Q, K, V and O.
+ static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO); + // Make sure we have enough shared memory. + static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/mask.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/mask.h new file mode 100644 index 00000000..00610950 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/mask.h @@ -0,0 +1,90 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +namespace fmha { + + +template +struct Mask { + using Mma_tile = fmha::Hmma_tile; + + template + __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0) + : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N) + , loop_step_idx(loop_step_idx_) { + + const int warp = tidx / Cta_tile::THREADS_PER_WARP; + const int lane = tidx % Cta_tile::THREADS_PER_WARP; + + static_assert(Cta_tile::WARPS_K == 1, ""); + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + row = warp_m * 16 + quad; + col = warp_n * 16 + tid; + } + + inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { + + // ii and jj iterate over the 2x4 fragment + // const int current_col = (Is_causal ? 
loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); + const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); + const int current_row = row_offset + ii * 8; + const bool col_valid = current_col < actual_seqlen_k; + // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k; + //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k; + // bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) { + // printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid); + // } + return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid; + // return row_valid && col_valid; + } + + //BERT Mask: if upper left is invalid, none are valid + inline __device__ bool any_valid(const int mi, const int ni) const { + return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0); + } + + inline __device__ void load(const int it) { + row_offset = it * Cta_tile::M + row; + } + int row_offset; + + int row; + int col; + const int loop_step_idx; + const int actual_seqlen_k; +}; + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/smem_tile.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/smem_tile.h new file mode 100644 index 00000000..80efee99 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/smem_tile.h @@ -0,0 +1,1703 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "utils.h" +#include "utils.h" +#include "gemm.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The description of the tile computed by this CTA. + typename Cta_tile, + // The number of rows in the 2D shared memory buffer. + int M_, + // The number of cols. 
+ int N_, + // The size in bits of each element. + int BITS_PER_ELEMENT_, + // The number of bytes per STS. + int BYTES_PER_STS_ = 16, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_ = 1, + // Do we enable the fast path for LDS.128 and friends. + int ENABLE_LDS_FAST_PATH_ = 0, + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + int ROWS_PER_XOR_PATTERN_ = 8, + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + int COLS_PER_XOR_PATTERN_ = 1, + // Use or not predicates + bool USE_PREDICATES_ = true +> +struct Smem_tile_without_skews { + + // The size in bits of each element. + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + // The size in bytes of a single STS. + enum { BYTES_PER_STS = BYTES_PER_STS_ }; + // The number of elements per STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + // To support arbitrary N, we pad some values to a power-of-2. + enum { N_WITH_PADDING = Next_power_of_two::VALUE }; + // The number of bytes per row without packing of rows. + enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; + // The number of bytes per row -- we want at least 128B per row. + enum { BYTES_PER_ROW = Max::VALUE }; + // The number of rows in shared memory (two rows may be packed into a single one). + enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; + + // The number of threads per row. + enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; + // The number of threads per row. + enum { THREADS_PER_ROW = Min::VALUE }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + // It must be at least one. + static_assert(STS_PER_ROW >= 1, ""); + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // Make sure we write to at least one row per STS. Thanks Dr. 
Obvious ;) + static_assert(ROWS_PER_STS >= 1, ""); + // The number of STS needed to store all rows. + enum { STS_PER_COL = Div_up::VALUE }; + // The number of STS in total. + enum { STS = STS_PER_COL * STS_PER_ROW }; + + // TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads, + // we only need to store 16 * 64 * 2 = 2KB instead of 4KB. + static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS; + static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA; + + // The size of one buffer in bytes in shared memory. + // enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; + enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS }; + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // Do we enable the LDS.128 fast path? + enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; + static_assert(ENABLE_LDS_FAST_PATH == 0); + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; + // Use or not predicates + enum { USE_PREDICATES = USE_PREDICATES_ }; + + // The type of elements that are stored in shared memory by each thread. + using Store_type = typename Uint_from_size_in_bytes::Type; + + // Ctor. + inline __device__ Smem_tile_without_skews(void *smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) { + + // The row written by a thread. See doc/mma_smem_layout.xlsx. 
+ int smem_write_row = tidx / THREADS_PER_ROW; + + // The XOR pattern. + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; + // Compute the column and apply the XOR pattern. + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + + // The offset. + this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS; + + // TODO: Why not merge it with the read offset? + // this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + // this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template< int N > + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { + #pragma unroll + for( int ii = 0; ii < N; ++ii ) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + int col = ii % STS_PER_ROW; + + // Assemble the offset. + int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW; + + // Take the column into account. + if( STS_PER_ROW > 1 ) { + offset += col*THREADS_PER_ROW*BYTES_PER_STS; + } + + // Apply the XOR pattern if needed. + if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) { + const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; + offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; + } + + // Assemble the final pointer :) + // ptrs[ii] = smem_ + offset + smem_write_buffer_; + // smem_write_buffer_ is already merged with smem_write_offset_ + ptrs[ii] = smem_ + offset; + } + } + + inline __device__ void debug_reset() { + for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for( int row = 0; row < ROWS; ++row ) { + for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { + if( threadIdx.x == 0 ) { + uint32_t val = 0x0; + sts(val, smem_ + row*BYTES_PER_ROW + col + buffer); + } + } + } + } + } + + // Print the content of the tile (only for debug ;)). 
+ inline __device__ void debug_print() const { + for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for( int row = 0; row < ROWS; ++row ) { + for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { + if( threadIdx.x == 0 ) { + uint32_t val; + lds(val, smem_ + row*BYTES_PER_ROW + col + buffer); + printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", + blockIdx.x, + blockIdx.y, + blockIdx.z, + smem_, + buffer, + row, + col, + val); + } + } + } + } + } + + // Move the read offset to next buffer. + inline __device__ void move_to_next_read_buffer() { + // if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { + // this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + // } else if( BUFFERS_PER_TILE > 1 ) { + // this->smem_read_buffer_ += BYTES_PER_BUFFER; + // } + if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) { + this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if( BUFFERS_PER_TILE > 1 ) { + this->smem_read_offset_ += BYTES_PER_BUFFER; + } + } + + // Move the read offset to next buffer. TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer() { + this->move_to_next_read_buffer(); + } + + // Move the read offset to next N buffer (circular-buffer). + inline __device__ void move_to_next_read_buffer(int N) { + if( BUFFERS_PER_TILE > 1 ) { + // this->smem_read_buffer_ += N * BYTES_PER_BUFFER; + // this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; + this->smem_read_offset_ += N * BYTES_PER_BUFFER; + this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; + } + } + + // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer(int N) { + this->move_to_next_read_buffer(N); + } + + // Move the write offset to next buffer. 
+ inline __device__ void move_to_next_write_buffer() { + // if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { + // this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + // } else if( BUFFERS_PER_TILE > 1 ) { + // this->smem_write_buffer_ += BYTES_PER_BUFFER; + // } + if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) { + this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if( BUFFERS_PER_TILE > 1 ) { + this->smem_write_offset_ += BYTES_PER_BUFFER; + } + } + + // Move the write offset to next buffer. TODO: Remove that member function! + inline __device__ void move_next_write_buffer() { + this->move_to_next_write_buffer(); + } + + // Move the read offset. + inline __device__ void move_read_offset(int delta) { + this->smem_read_offset_ += delta; + } + + // Move the write offset. + inline __device__ void move_write_offset(int delta) { + this->smem_write_offset_ += delta; + } + + // Store to the tile in shared memory. + template< int N > + inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + // Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per buffer. + if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) { + sts(smem_ptrs, data); + } + } + + // Store to the tile in shared memory. + template< int N, int M > + inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. + template< int N > + inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { + this->store(data, preds); + } + + // Store to the tile in shared memory. 
+ template< int N > + inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { + uint32_t tmp[1] = { preds }; + this->store(gmem_ptrs, tmp); + } + + // The shared memory pointer. + const uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + // int smem_read_buffer_; + // The buffer base offset for write. + // int smem_write_buffer_; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true +> +struct Smem_tile_a { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int MMAS_K, int MMAS_K_WITH_PADDING > +struct Compute_reset_mask { + // The potential mask. + enum { HALF = MMAS_K_WITH_PADDING / 2 }; + // The remainder. + enum { MOD = MMAS_K % HALF }; + // The final value. + enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int MMAS_K_WITH_PADDING > +struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { + enum { VALUE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int MMAS_K > +struct Compute_reset_mask { + enum { VALUE = MMAS_K - 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +struct Rows_per_xor_pattern_a { + // The size in bits. 
+ enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE +> +struct Smem_tile_row_a : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = fmha::Hmma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) { + + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + + static_assert(WARPS_M == 1); + static_assert(WARPS_N == 4 || WARPS_N == 8); + static_assert(WARPS_K == 1); + static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8); + + // The row and column read by the thread. 
+ int smem_read_row = (tidx & 0x0f); + constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; + smem_read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if( Mma_tile_with_padding::MMAS_K >= 2 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { + #pragma unroll + for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset); + + // Store the value into the fragment. + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; + } + + // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. 
+ static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE +> +struct Smem_tile_a + : public Smem_tile_row_a { + // The base class. + using Base = Smem_tile_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) { + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. 
+ int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true +> +struct Smem_tile_b { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +struct Rows_per_xor_pattern_b { + // The size in bits. + enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE +> +struct Smem_tile_col_b : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b< Col>; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>; + // The number of MMAs. + using Mma_tile_with_padding = fmha::Hmma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. 
+ inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) { + + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8); + static_assert(WARPS_M == 1); + static_assert(WARPS_N == 4 || WARPS_N == 8); + static_assert(WARPS_K == 1); + + // The masks to select the warps. + const int WARP_MASK_N = Warp_masks::N; + + // The divisor for the warps. + const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + + (tidx & 0x07) + + (tidx & 0x10) / 2; + constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; + smem_read_col ^= (tidx & 0x08) / 8; + // The shared memory offset. + this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if( Mma_tile_with_padding::MMAS_K >= 2 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + #pragma unroll + for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. 
+ uint4 tmp; + // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset); + + // Store the value into the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE +> +struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE > + : public Smem_tile_col_b { + + // The base class. 
+ using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>; + + // Ctor. + inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b::VALUE, + // How many cols to use for the XOR pattern to avoid bank conflicts? + int COLS_PER_XOR_PATTERN_ = 1 +> +struct Smem_tile_row_b : public Smem_tile_without_skews { + + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // Can we use LDSM? No if the data type is 32-bit large. + enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 }; + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; + // The number of elements per LDS. + enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) { + + // The number of warps. 
+ const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + static_assert(WARPS_K == 1); + static_assert(WARPS_M == 4 || WARPS_M == 8); + static_assert(WARPS_N == 1); + + // The masks to select the warps. + const int WARP_MASK_N = Warp_masks::N; + const int WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + + static_assert(USE_LDSMT); + static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8); + + // The row/col read by the thread. + int smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + + (tidx & 0x07) + (tidx & 0x08); + constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS; + + // Fill zeroes for group conv + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // The size of each element in bits. + const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + + #pragma unroll + for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if( BYTES_PER_MMA_PER_CTA >= 128 ) { + // Nothing to do! 
+ } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if( BYTES_PER_MMA_PER_CTA == 64 ) { + // Nothing to do! + } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && + Mma_tile::MMAS_N % 2 == 1 ) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + // The size of each element in bits. + const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + + // uint32_t smem_read_og = this->smem_ + this->smem_read_offset_; + #pragma unroll + for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { + // Prepare the offset. + int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING; + if ( BYTES_PER_MMA_PER_CTA == 32 ) { + offset += this->smem_read_offset_; + } else if ( BYTES_PER_MMA_PER_CTA == 64 ) { + offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2; + } else { + offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA; + } + + // Load the data using LDSM.MT88.2. 
+ // uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; + uint32_t ptr = this->smem_ + offset; + uint4 tmp; + if( USE_LDSMT ) { + ldsmt(tmp, ptr); + } else { + lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING); + lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING); + lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING); + lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING); + } + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og); + // } + // Store those values in the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + if( BYTES_PER_MMA_PER_CTA >= 128 ) { + // Nothing to do! + } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if( BYTES_PER_MMA_PER_CTA == 64 ) { + // Nothing to do! + } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2)); + } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && + Mma_tile::MMAS_N % 2 == 1 ) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. 
+ int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE +> +struct Smem_tile_b + : public Smem_tile_row_b { + + // The base class. + using Base = Smem_tile_row_b; + + // Ctor. + inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v : public fmha::Smem_tile_without_skews::VALUE, 1> { + + // The base class. + using Base = Smem_tile_without_skews::VALUE, 1>; + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The fragment. + using Fragment = Fragment_b< fmha::Col>; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) { + + // The row/col read by the thread. + int read_row, read_col; + + static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); + + read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); + constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + read_col = ((read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; + read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { + // Jump by 16 * #warps row. + int row = ki * 16 * Cta_tile::WARPS_K; + + // Load the data using LDSM.MT88.2. + uint4 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW_BEFORE_PACKING); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + // Move the pointer for the next ni. 
I expect the compiler to not recompute those. + if( Mma_tile::MMAS_N == 1 ) { + // noop + } else if( Mma_tile::MMAS_N == 2 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } else if( Mma_tile::MMAS_N == 4 ) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (Mma_tile::MMAS_N == 8) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2)); + } else { + assert(false); // Not implemented! + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o { + + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The accumulators. + using Data_type = typename Accumulator::Data_type; + + // The size of each element. + static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type); + // The size of each STS. + static constexpr int BYTES_PER_STS = 8; + // The size of each row in shared memory. + static constexpr int BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT; + + // The size of each LDS. + static constexpr int BYTES_PER_LDS = 16; + static constexpr int THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS; + + // The number of rows. + static constexpr int ROWS = Cta_tile::M; + // The number of "rows" to process per loop iteration (in the "epilogue"). + static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA; + // The number of outer loops. + static constexpr int LOOPS = ROWS / ROWS_PER_LOOP; + // Make sure it matches our expectations. + static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of rows loaded per LDS. + static constexpr int ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; + // Do we have to guard against partial writes/reads. 
+ static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0; + // The total number of LDS per loop. + static constexpr int LDS_PER_LOOP = fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS); + + // The amount of shared memory. + static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW; + + // The write pointer. + uint32_t smem_write_, smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; + + // static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); + static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); + + // Ctor. + inline __device__ Smem_tile_o(void *smem, int tidx) { + + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); + static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128); + + int write_row = (tidx & 0x1c) / 4; + + const int lane = tidx % 32; + const int warp = tidx / 32; + + constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT; + constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS; + int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP; + + // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("write_row = %d, write_col = %d\n", write_row, write_col); + // } + + // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) { + // printf("threadIdx.x = %d\n", threadIdx.x); + // } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 
4 : 8))); + // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8)))); + + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("read_row = %d, read_col = %d\n", read_row, read_col); + // } + // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) { + // printf("threadIdx.x = %d\n", threadIdx.x); + // } + // Assemble the read pointer. + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if( HAS_INCOMPLETE_LDS ) { + this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + template + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { + #pragma unroll + for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) { + + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; + #pragma unroll + for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; + uint32_t smem_read = this->smem_read_ + imm; + // TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way. + if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) { + smem_read ^= 8 * BYTES_PER_LDS; + } + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("imm diff = %d\n", smem_read - this->smem_read_); + // } + if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) { + // fmha::lds(tmp[jj], this->smem_read_ + imm); + fmha::lds(tmp[jj], smem_read); + } + } + + // Perform the reduction. + out[ii] = zero_init ? 
tmp[0] : fmha::fadd4(out[ii], tmp[0]); + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("out reduction: out = %.6f\n", reinterpret_cast(out[ii])[0]); + // } + #pragma unroll + for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast(tmp[jj])[0], reinterpret_cast(out[ii])[0]); + // } + } + } + } + + // Store the accumulators. + template + inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { + // uint32_t smem_write_og = this->smem_write_; + static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA; + #pragma unroll + for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { + + // The number of MMAs that are stored per loop iteration. + static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS; + + // Store 1st column of the different MMAs. + #pragma unroll + for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og); + // } + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // uint4 read_tmp; + // fmha::lds(read_tmp, this->smem_read_); + // printf("smem_o = %.6f\n", reinterpret_cast(read_tmp)[0]); + // } + // Swizzle the write pointer using a XOR of 16B. 
+ this->smem_write_ ^= 32; + + // Store 2nd column of the different MMAs. + #pragma unroll + for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + + // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og); + // } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + static_assert(Mma_tile::MMAS_N <= 8, "Not implemented"); + if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) { + this->smem_write_ ^= 15 * 32; + } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) { + this->smem_write_ ^= 7 * 32; + } else if( Mma_tile::MMAS_N >= 2 ) { + this->smem_write_ ^= 3 * 32; + } else { + this->smem_write_ ^= 3 * 32; + } + // this->smem_write_ ^= (ni & 1) ? 
7 * 32 : 3 * 32; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // uint4 read_tmp; + // fmha::lds(read_tmp, this->smem_read_); + // printf("smem_o = %.6f\n", reinterpret_cast(read_tmp)[0]); + // } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_mma { + + using Mma_tile = fmha::Hmma_tile; + using Fragment = fmha::Fragment_a; + + enum { COLS = Cta_tile::N }; + enum { BYTES_PER_ELT = 2 }; + enum { BYTES_PER_STS = 4 }; + enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO + enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K == 1); + inline __device__ Smem_tile_mma(char *smem, int tidx) { + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + int write_col, write_row; + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1); + if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); + write_col ^= (write_row & 0x07) * 4; + } else { + write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x03); + // write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4; + write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 
0x07 : 0x07)))) * 4; + } + + // write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + } + + template + inline __device__ void store(const uint4 (&regs)[M][N]) { + static_assert(COLS == Cta_tile::N); + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + // offset ^= 4 * BYTES_PER_STS; + // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + } + } + } + + template + inline __device__ void store(const Fragment (&frag)[N][M]) { + static_assert(COLS == Cta_tile::N); + uint4 regs[M][N]; + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // Need to transpose reg(1) and reg(2) here since when we load it we transpose again. 
+ regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2), + frag[ni][mi].reg(1), frag[ni][mi].reg(3)); + } + } + this->store(regs); + } + + // uint32_t smem_; + // uint32_t write_offset_; + uint32_t smem_write_; +}; + +template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> +struct Smem_tile_mma_transposed : public Base { + enum { BYTES_PER_LDS = 16 }; + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; + enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; + enum { WARPS_M = Base::WARPS_M }; + enum { WARPS_N = Base::WARPS_N }; + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + using Fragment = typename Base::Fragment; + inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) { + + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + int read_row, read_col; + read_row = (tidx & 0x0f); + read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; + + // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))); + read_col ^= (read_row & 0x07); + // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + template + inline __device__ void load(Fragment (&frag)[M][N]) { + static_assert(Base::COLS == Cta_tile::N); + for( int mi = 0; mi < M; mi++ ) { + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + // fmha::ldsmt(dst, this->smem_ + offset); + // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::ldsmt(dst, offset); + frag[mi][ni].reg(0) = dst.x; + frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! 
+ frag[mi][ni].reg(2) = dst.y; + frag[mi][ni].reg(3) = dst.w; + } + } + } + + // uint32_t read_offset_; + uint32_t smem_read_; +}; + +template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> +struct Smem_tile_mma_epilogue : public Base { + enum { BYTES_PER_LDS = 16 }; + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; + enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; + static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); + enum { WARPS_M = Base::WARPS_M }; + enum { WARPS_N = Base::WARPS_N }; + static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); + + using Acc = fmha::Fragment_accumulator; + + inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + const int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07))); + static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256); + read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 
0x07 : 0x07)))); + // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + inline __device__ void load(uint4 (&data)[NUM_LDS]) { + for( int ii = 0; ii < NUM_LDS; ii++ ) { + // size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + // fmha::lds(data[ii], this->smem_ + offset); + // size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + fmha::lds(data[ii], offset); + } + } + + template + inline __device__ void store(const Acc (&acc)[M][N]){ + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // 1st row - 4 elements per row. + float tmp00 = acc[mi][ni].elt(0); + float tmp01 = acc[mi][ni].elt(1); + float tmp02 = acc[mi][ni].elt(4); + float tmp03 = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. + float tmp10 = acc[mi][ni].elt(2); + float tmp11 = acc[mi][ni].elt(3); + float tmp12 = acc[mi][ni].elt(6); + float tmp13 = acc[mi][ni].elt(7); + + uint32_t x = fmha::float2_pack(tmp00, tmp01); + uint32_t y = fmha::float2_pack(tmp02, tmp03); + uint32_t z = fmha::float2_pack(tmp10, tmp11); + uint32_t w = fmha::float2_pack(tmp12, tmp13); + + // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); + // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); + // offset ^= 4 * Base::BYTES_PER_STS; + // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); + // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); + // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - 
this->smem_write_); + // } + fmha::sts(offset + 0 * BYTES_PER_ROW, x); + fmha::sts(offset + 8 * BYTES_PER_ROW, z); + offset ^= 4 * Base::BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, y); + fmha::sts(offset + 8 * BYTES_PER_ROW, w); + } + } + } + + template + inline __device__ void store(const uint4 (&regs)[M][N]) { + for( int mi = 0; mi < M; mi++ ) { + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + offset ^= 4 * Base::BYTES_PER_STS; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + } + } + } + + // uint32_t read_offset_; + uint32_t smem_read_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_transpose { + + using Mma_tile = fmha::Hmma_tile; + using Fragment_write = fmha::Fragment_b; + using Fragment_read = fmha::Fragment_b; + + enum { COLS = Cta_tile::N }; + enum { BYTES_PER_ELT = 2 }; + enum { BYTES_PER_STS = 4 }; + enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO + enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; + + enum { BYTES_PER_LDS = 16 }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K == 1); + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + + inline __device__ Smem_tile_transpose(char *smem, int tidx) { + smem_ = __nvvm_get_smem_pointer(smem); + // uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + int write_col, write_row; + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 
1); + if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); + } else { + write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x03); + } + write_col ^= (write_row & 0x07) * 4; + + write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + int read_row, read_col; + read_row = (tidx & 0x0f); + read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; + + read_col ^= (read_row & 0x07); + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + template + inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + } + } + + template + inline __device__ void load(Fragment_read (&frag_r)[N]) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + fmha::ldsmt(dst, this->smem_ + offset); + frag_r[ni].reg(0) = dst.x; + frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! 
+ frag_r[ni].reg(2) = dst.z; + frag_r[ni].reg(3) = dst.w; + } + } + + template + inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) { + static_assert(COLS == Cta_tile::N); + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + } + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + fmha::ldsmt(dst, this->smem_ + offset); + frag_r[ni].reg(0) = dst.x; + frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! + frag_r[ni].reg(2) = dst.z; + frag_r[ni].reg(3) = dst.w; + } + } + + uint32_t smem_; + uint32_t write_offset_; + uint32_t read_offset_; + // uint32_t smem_write_; + // uint32_t smem_read_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + typename Gmem_tile, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_ = 1 +> +struct Smem_tile_dp_sum { + + using Cta_tile = typename Gmem_tile::Cta_tile; + using Mma_tile = fmha::Hmma_tile; + + // The size of each element. 
+ static constexpr int BYTES_PER_ELEMENT = 4; + static constexpr int ROWS = Gmem_tile::ROWS; + static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW; + static constexpr int MMAS_M = Mma_tile::MMAS_M; + + static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG; + static constexpr int LDGS = Gmem_tile::LDGS; + + static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA; + + // The size of one buffer in bytes in shared memory. + static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT; + // The number of buffers. + static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_; + // The size in bytes of total buffers. + static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE; + // The boundary for smem_read_offset and smem_write_offset increment. + static constexpr int ROWS_PER_TILE_INC_BOUNDARY = ROWS * BUFFERS_PER_TILE - ROWS; + + inline __device__ Smem_tile_dp_sum(float *smem, const int tidx) + : smem_(smem), smem_read_buffer_(smem), smem_write_buffer_(smem), tidx_(tidx) { + } + + // Move the read offset to next buffer. + inline __device__ void move_to_next_read_buffer() { + if( BUFFERS_PER_TILE > 1 && (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) { + this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY; + } else if( BUFFERS_PER_TILE > 1 ) { + this->smem_read_buffer_ += ROWS; + } + } + + // Move the write offset to next buffer. 
+ inline __device__ void move_to_next_write_buffer() { + if( BUFFERS_PER_TILE > 1 && (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) { + this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY; + } else if( BUFFERS_PER_TILE > 1 ) { + this->smem_write_buffer_ += ROWS; + } + } + + inline __device__ void store(const float (&sum)[LDGS]) { + if (tidx_ % THREADS_PER_ROW == 0) { + int row = tidx_ / THREADS_PER_ROW; + #pragma unroll + for (int i = 0; i < LDGS; ++i) { + if (row + i * ROWS_PER_LDG < ROWS) { + smem_write_buffer_[row + i * ROWS_PER_LDG] = sum[i]; + } + } + } + } + + inline __device__ void store(const float sum, const int buffer_idx) { + float *smem_write = smem_ + buffer_idx * ROWS; + int row = tidx_ / THREADS_PER_ROW; + if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) { + smem_write[row] = sum; + } + } + + inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) { + float *smem_write = smem_ + buffer_idx * ROWS; + if (tidx_ % THREADS_PER_ROW == 0) { + int row = tidx_ / THREADS_PER_ROW; + #pragma unroll + for (int i = 0; i < LDGS; ++i) { + if (row + i * ROWS_PER_LDG < ROWS) { + smem_write[row + i * ROWS_PER_LDG] = sum[i]; + } + } + } + } + + inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) { + float *smem_write = smem_; + // Extract the position in the warp. + int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + int row = lane / 4; + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0]; + smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1]; + } + } + + inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) { + float *smem_write = smem_ + buffer_idx * ROWS; + // Extract the position in the warp. 
+ int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + int row = lane / 4; + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0]; + smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1]; + } + } + + template + inline __device__ void load(float (&sum)[N], const int (&row)[N]) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + sum[ni] = smem_read_buffer_[row[ni]]; + } + } + + template + inline __device__ void load(float (&sum)[N], const int (&row)[N], const int buffer_idx) { + float *smem_read = smem_ + buffer_idx * ROWS; + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + sum[ni] = smem_read[row[ni]]; + } + } + + static inline __device__ float reduce_warp(float sum) { + fmha::SumOp sum_op; + return fmha::Allreduce::run(sum, sum_op); + } + + const int tidx_; + float * const smem_; + float *smem_read_buffer_; + float *smem_write_buffer_; +}; + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/softmax.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/softmax.h new file mode 100644 index 00000000..af559835 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/softmax.h @@ -0,0 +1,607 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float apply_exp_(float x, float max) { + return __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float apply_exp2_(float x, float max) { + return exp2f(x - max); + // With fast-math, this produces the same PTX instruction as the assembly below + // float diff = x - max; + // float res; + // asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff)); + // return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ReadType {}; +template<> struct ReadType<4> { using T = float;}; +template<> struct ReadType<8> { using T = float2;}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_reduce { + // Helper class to distribute MMA tiles reduced over rows per warp over quads. + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + static constexpr int MMAS_N = Mma_tile::MMAS_N; + + static constexpr int WARPS_M = Cta_tile::WARPS_M; + static constexpr int WARPS_N = Cta_tile::WARPS_N; + + + static constexpr int ROWS = WARPS_M * MMAS_M * 16; + static constexpr int COLS = WARPS_N; + static_assert(COLS == 4 || COLS == 8); + static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; + static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); + static constexpr int ELTS_PER_TILE = ROWS * COLS; + + static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; + // TD [2022-05-02]: No longer true if head_dim != 64 + // static_assert(THREADS_PER_GROUP == 16); // DEBUG + static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; + static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; + static_assert(LOOPS == 1); + + using read_t = typename ReadType::T; + + __device__ inline Smem_tile_reduce(float *smem_, const int tidx) { + + int lane = tidx % 32; + int warp = tidx / 32; + + int warp_m = warp % WARPS_M; + int warp_n = warp / WARPS_M; + + qid_ = lane % 4; + int qp = lane / 4; + + // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. + // This won't affect reading as we assume commutative reduction ops. 
+ const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); + smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; + smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; + smem_read_row_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qid_]; + + } + + __device__ inline void store(float (&frag)[2 * MMAS_M]) { + if( qid_ == 0 ) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + int offset = mi * 16 * WARPS_N; + smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; + smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; + } + } + } + + __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + int offset = mi * 16 * 4; + frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; + frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; + } + } + + __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + int offset = mi * 16 * 4; + frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4]; + } + } + + int qid_; + float *smem_write_; + read_t *smem_read_; + read_t *smem_read_row_; + +}; + + +template +struct Softmax_base { + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + static constexpr int MMAS_N = Mma_tile::MMAS_N; + + // The number of groups of warp such that we have at most 4 warps writing consecutive elements. + static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4); + // The number of elements that we are going to store per row. + static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS; + // The number of rows. + static constexpr int ROWS = Cta_tile::M * GROUPS; + // The total number of elements. + static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW; + + // Ctor. 
+ template + inline __device__ Softmax_base(const Params &params, void *smem, int tidx) + : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), + smem_(reinterpret_cast(smem)), tidx_(tidx) { + + // Move to the 1st mask loaded by the thread+ tidx; + // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t); + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The location written by the threads. + int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; + } + + template + inline __device__ void apply_mask(const Mask &mask) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + #pragma unroll + for( int jj = 0; jj < 4; ++jj ) { + if( !mask.is_valid(mi, ni, ii, jj) ) { + elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; + } + } + } + } + } + } + + // Apply the exp to all the elements. + template + inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + constexpr float kLog2e = M_LOG2E; + const float max_base2 = max_in_base2 ? 
max[mi] : max[mi] * kLog2e; + #pragma unroll + for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + // elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e, + max_base2); + } + } + } + + // Apply the exp to all the elements. + template + inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) { + const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E; + const float scale = scale_ * M_LOG2E; + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + const float max_scaled = max[mi] * max_scale; + #pragma unroll + for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled); + } + } + } + + // Apply the exp to all the elements. + template + inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) { + #pragma unroll + for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + constexpr float kLog2e = M_LOG2E; + const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e; + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2); + } + } + } + // inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) { + // constexpr float kLog2e = M_LOG2E; + // #pragma unroll + // for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + // float max_base2 = max_in_base2 ? 
max[ni / 4] : max[ni / 4] * kLog2e; + // max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8); + // #pragma unroll + // for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + // elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2); + // } + // } + // } + + template + inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); + }; + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ni++ ) { + uint16_t tmp[8]; + // fmha::uint4_to_ushort8(ph(), tmp); + uint4 tmp_32 = ph(); + fmha::uint4_to_ushort8(tmp_32, tmp); + // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = + encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); + } + } + } + } + } + + template + inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t, + unsigned long long philox_subsequence) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); + }; + static_assert(MMAS_M == 1); // We're assuming 16x16 blocks. 
+ #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ni++ ) { + uint16_t tmp[8]; + // fmha::uint4_to_ushort8(ph(), tmp); + fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp); + // uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N); + // fmha::uint4_to_ushort8(tmp_32, tmp); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = + encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); + } + } + } + } + } + + template + inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? 
-val : float(0)); + }; + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + static_assert(MMAS_N % 2 == 0); + #pragma unroll + for( int ni = 0; ni < MMAS_N; ni += 2 ) { + uint16_t tmp[8]; + fmha::uint4_to_ushort8(ph0(), tmp); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = + encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); + } + } + fmha::uint4_to_ushort8(ph1(), tmp); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * (ni + 1) + jj] = + encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]); + } + } + } + } + } + + // Scale all the elements. + inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M * 2]; + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + + // Update the values. + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // Subtract all elements by dp_sum + inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) { + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + elt_[mi][ni] -= dp_sum[mi]; + } + } + } + + // The pointer to the mask. 
+ const char *packed_mask_ptr_; + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax : public Softmax_base { + + // The base class. + using Base = Softmax_base; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + static_assert(Fragment_a::NUM_REGS == 4); + + static constexpr int WARPS_M = Cta_tile::WARPS_M; + static constexpr int WARPS_N = Cta_tile::WARPS_N; + // The MMAs. + static constexpr int MMAS_M = Base::MMAS_M; + static constexpr int MMAS_N = Base::MMAS_N; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + using Accumulator_out = Fragment; + static_assert(Accumulator_out::NUM_REGS == 4); + + static_assert(std::is_same::value); + + using Smem_tile_red = Smem_tile_reduce; + static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); + // Ctor. + template + inline __device__ Softmax(const Params ¶ms, void *smem, int tidx) + : Base(params, smem, tidx) + , params_scale_bmm1_(params.scale_bmm1) + , smem_sum_(static_cast(smem), tidx) + , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + #pragma unroll + for( int mi = 0; mi < M; ++mi ) { + #pragma unroll + for( int ki = 0; ki < K; ++ki ) { + + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. 
+ float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_pack(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_pack(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_pack(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_pack(tmp_12, tmp_13); + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { + const float scalef = reinterpret_cast(this->params_scale_bmm1_); + + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) { + + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. 
+ this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); + } + } + } + + template + __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) { + #pragma unroll + for( int mi = 0; mi < 2 * MMAS_M; mi++ ) { + frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]); + #pragma unroll + for( int ni = 1; ni < 4 * MMAS_N; ni++ ) { + frag[mi] = op(frag[mi], this->elt_[mi][ni]); + } + } + } + + template + __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) { + thread_reduce_(frag, op); + quad_reduce(frag, frag, op); + smem_red.store(frag); + __syncthreads(); + typename Smem_tile_red::read_t tmp[2 * MMAS_M]; + smem_red.load(tmp); + quad_allreduce(frag, tmp, op); + } + + template + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ + MaxOp max; + reduce_(frag, max, smem_max_); + } + + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ + SumOp sum; + reduce_(frag, sum, smem_sum_); + } + + template + __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){ + SumOp sum; + thread_reduce_(frag, sum); + quad_reduce(frag, frag, sum); + smem_sum_.store(frag); + } + + template + __device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS], + Operator &op, Smem_tile_red & smem_red) { + #pragma unroll + for (int ii = 0; ii < NROWS; ii++) { + typename Smem_tile_red::read_t tmp[MMAS_M]; + smem_red.load_row(tmp, rows[ii]); + quad_allreduce(frag[ii], tmp, op); + } + } + + template + __device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS]){ + SumOp sum; + reduce_after_sync_(frag, rows, sum, smem_sum_); + } + + template + __device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M], + const int 
(&rows)[NROWS]){ + MaxOp max; + reduce_after_sync_(frag, rows, max, smem_max_); + } + + const uint32_t params_scale_bmm1_; + Smem_tile_red smem_max_; + Smem_tile_red smem_sum_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha/utils.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha/utils.h new file mode 100644 index 00000000..38e4e741 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha/utils.h @@ -0,0 +1,1215 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Row {}; +struct Col {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int M, bool = (M & (M-1)) == 0 > +struct Next_power_of_two { +}; + +template< int M > +struct Next_power_of_two< M, true > { enum { VALUE = M }; }; +template<> +struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; }; +template<> +struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; }; +template<> +struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; }; +template<> +struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; }; +template<> +struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; }; +template<> +struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; }; +template<> +struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; }; +template<> +struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; }; +template<> +struct Next_power_of_two< 13, false> { enum { VALUE 
= 16 }; }; +template<> +struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; }; +template<> +struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; }; +template<> +struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; }; +template<> +struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; }; +template<> +struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; }; +template<> +struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; }; +template<> +struct Next_power_of_two<112, false> { enum { VALUE = 128 }; }; +template<> +struct Next_power_of_two<144, false> { enum { VALUE = 256 }; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, bool = (N & (N-1)) == 0 > +struct Prev_power_of_two { +}; + +template< int N > +struct Prev_power_of_two< N, true > { enum { VALUE = N }; }; +template<> +struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; }; +template<> +struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; }; +template<> +struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; }; +template<> +struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int M, int N > +struct Div_up { + enum { VALUE = (M + N-1) / N }; +}; + +constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int A, int B > +struct Max { + enum { VALUE = A >= B ? A : B }; +}; + +constexpr int MaxConstexpr(int A, int B) { return A >= B ? 
A : B; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int A, int B, int C > +struct Max_3 { + enum { VALUE = Max::VALUE, C>::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int A, int B > +struct Min { + enum { VALUE = A <= B ? A : B }; +}; + +constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int SIZE_IN_BYTES > +struct Uint_from_size_in_bytes { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<1> { + using Type = uint8_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<2> { + using Type = uint16_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<4> { + using Type = uint32_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<8> { + using Type = uint2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<16> { + using Type = uint4; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int WARPS_M, int WARPS_N, int WARPS_K > +struct Warp_masks { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; }; +template<> +struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 
0x00 }; }; +template<> +struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; }; +template<> +struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; }; +template<> +struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; }; +template<> +struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; }; +template<> +struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; }; +template<> +struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; }; +template<> +struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; }; +template<> +struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; +template<> +struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; +template<> +struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; +template<> +struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; +template<> +struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; +template<> +struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; +template<> +struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; +template<> +struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline __device__ __host__ T div_up(T m, T n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int clz(int x) { + for( int i = 31; i >= 0; --i ) { + if( (1 << i) & x ) { + return 31 - i; + } + } + return 32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int find_log_2(int x, bool round_up = false) { + int a = 31 - clz(x); + if( round_up ) { + a += (x & (x-1)) ? 
1 : 0; + } + return a; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) { + // uint32_t c; + // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + // return c; + __half2 result = __hmul2(reinterpret_cast(a), + reinterpret_cast(b)); + return reinterpret_cast(result); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hmul4(uint2 a, uint2 b) { + uint2 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint4 a, uint4 b) { + uint4 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + c.z = hmul2(a.z, b.z); + c.w = hmul2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t hrelu2(uint32_t x); + +template<> +inline __device__ uint32_t 
hrelu2<__half>(uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t habs2(uint32_t x) { + uint32_t res; + asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +static inline __device__ T clamp(T x, T lb, T ub) { + return x < lb ? lb : (x > ub ? 
ub : x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t clamp_to_zero(uint16_t x) { + uint16_t mask; + asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); + return mask & x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t float_to_half(float f) { + uint16_t h; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); + return h; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(float a, float b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); +#else + uint16_t lo = float_to_half(a); + uint16_t hi = float_to_half(b); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t float2_pack(float a, float b); + +template <> +inline __device__ uint32_t float2_pack<__half>(float a, float b) { + __half2 result = __floats2half2_rn(a, b); + return reinterpret_cast(result); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) { + __nv_bfloat162 result = __floats2bfloat162_rn(a, b); + return reinterpret_cast(result); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float_to_half2(float a) { + return float2_to_half2(a,a); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(const float2 &f) { + return 
float2_to_half2(f.x, f.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint2 float4_pack(float x, float y, float z, float w) { + uint2 d; + d.x = float2_pack(x, y); + d.y = float2_pack(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); +#else + d = hrelu2<__half>(hfma2(a, b, c)); +#endif + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h0_h0(uint32_t x) { + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" + : "=r"(y) : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float h0_to_float(uint32_t h2) { + float f; + asm volatile("{\n" \ + ".reg .f16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %1;\n" \ + "cvt.f32.f16 %0, lo;\n" \ + "}\n" : "=f"(f) : "r"(h2)); + return f; +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h1_h1(uint32_t x) { + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" + : "=r"(y) : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { + return hadd2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd4(uint2 a, uint2 b) { + uint2 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd(uint2 a, uint2 b) { + return hadd4(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd8(uint4 a, uint4 b) { + uint4 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + c.z = hadd2(a.z, b.z); + c.w = hadd2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2 (&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); +} +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = fmha::half2_unpack(a); + float2 bf = fmha::half2_unpack(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = fmha::hfma2_to_float(a.x, b.x); + sum += fmha::hfma2_to_float(a.y, b.y); + sum += fmha::hfma2_to_float(a.z, b.z); + sum += fmha::hfma2_to_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 fadd4(uint4 a, uint4 b) { + float4 c; + c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); + c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); + c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); + c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); + return reinterpret_cast(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 fmul4(uint4 a, float b) { + float4 c; + c.x = reinterpret_cast(a.x) * b; + c.y = reinterpret_cast(a.y) * b; + c.z = reinterpret_cast(a.z) * b; + c.w = reinterpret_cast(a.w) * b; + return reinterpret_cast(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd(uint4 a, uint4 b) { + return hadd8(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float half_to_float(uint16_t h) { + float f; + asm 
volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float2 half2_to_float2(uint32_t x) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) { + float2 tmp = half2_to_float2(h); + x = tmp.x; + y = tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { + uint16_t d; + asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void uint4_to_ushort8(const uint4 a, uint16_t (&b)[8]) { + uint32_t *b_tmp = reinterpret_cast(&b[0]); + b_tmp[0] = a.x; + b_tmp[1] = a.y; + b_tmp[2] = a.z; + b_tmp[3] = a.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float sigmoid(float x) { + return 1.f / (1.f + expf(-x)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint16_t &dst) { + dst = uint16_t(0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline 
__device__ void clear(uint32_t &dst) { + dst = 0u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint2 &dst) { + dst = make_uint2(0u, 0u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint4 &dst) { + dst = make_uint4(0u, 0u, 0u, 0u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// P R E D I C A T E P A C K I N G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// +enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// G E N E R I C P R E D I C A T E D L D G S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, int M, typename Functor > +inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) { + + // The number of complete bytes (where we use all the predicates in a byte). + enum { COMPLETE = N / PREDS_PER_BYTE }; + // Make sure we did allocate enough predicates. + static_assert(Div_up::VALUE <= M, ""); + // The remainder. + enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + + // Clear the fetch registers. + #pragma unroll + for( int ii = 0; ii < N; ++ii ) { + fct.clear(ii); + } + + // Run complete steps. + bool p[PREDS_PER_BYTE]; + #pragma unroll + for( int ii = 0; ii < COMPLETE; ++ii ) { + + // The predicate. + uint32_t reg = preds[ii / BYTES_PER_REG]; + + // Extract the predicates. 
+ #pragma unroll + for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { + uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + + // Issue the loads. + #pragma unroll + for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { + fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); + } + } + + // Skip the rest of the code if we do not have a remainder. + if( REMAINDER > 0 ) { + + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + + // The predicate register. + uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; + + // Extract the predicates. + #pragma unroll + for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { + uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + + // Issue the loads. + #pragma unroll + for( int ii = 0; ii < REMAINDER; ++ii ) { + fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int M, typename Functor > +inline __device__ void load_(Functor &fct, uint32_t preds) { + uint32_t tmp[1] = { preds }; + load_(fct, tmp); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint8_t &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint16_t &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint32_t &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline 
__device__ void ldg(uint2 &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint4 &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Data_type, int N > +struct Ldg_functor { + // Ctor. + inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) + : fetch_(fetch), ptrs_(ptrs) { + } + + // Clear the element. + inline __device__ void clear(int ii) { + fmha::clear(fetch_[ii]); + } + + // Trigger the loads. + inline __device__ void load(int ii, bool p) { + if( p ) { + ldg(fetch_[ii], ptrs_[ii]); + } + } + + // The fetch registers. + Data_type (&fetch_)[N]; + // The pointers. + const void* (&ptrs_)[N]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Data_type, int N, int M > +inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { + Ldg_functor fct(fetch, ptrs); + load_(fct, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, int M > +inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, int M > +inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, int M > +inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, int M > +inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N, int M > +inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint16_t &dst, uint32_t ptr) { + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint32_t &dst, uint32_t ptr) { + asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint2 &dst, uint32_t ptr) { + asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint4 &dst, uint32_t ptr) { + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x) + , "=r"(dst.y) + , "=r"(dst.z) + , "=r"(dst.w) + : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S M +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ 
>= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(dst) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" + : "=r"(dst) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint2 &dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint4 &dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); +#endif +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint8_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint16_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint32_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint2 val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint4 val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint16_t val) { + asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint32_t val) { + asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint2 val) { + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" + : + : "r"(ptr) + , "r"(val.x) + , "r"(val.y)); +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint4 val) { + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(ptr) + , "r"(val.x) + , "r"(val.y) + , "r"(val.z) + , "r"(val.w)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Data_type, int N > +inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) { + #pragma unroll + for( int ii = 0; ii < N; ++ii ) { + sts(ptrs[ii], data[ii]); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) { + #pragma unroll + for(int mi=0; mi < M; mi++){ + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) { + #pragma unroll + for(int mi=0; mi < M; mi++){ + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) { + float tmp[M]; + #pragma unroll + for(int mi=0; mi < M; mi++){ + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) { + __half2 tmp[M]; + #pragma unroll + for(int mi=0; mi < M; mi++){ + tmp[mi] = op(reinterpret_cast(src[mi].x), + reinterpret_cast(src[mi].y)); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) { + #pragma unroll + for(int mi=0; mi < M; mi++){ + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) { + #pragma unroll + for(int mi=0; mi < M; mi++){ + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) { + float tmp[M]; + #pragma unroll + for(int mi=0; mi < M; mi++){ + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) { + __half2 tmp[M]; + #pragma unroll + for(int mi=0; mi < M; mi++){ + 
tmp[mi] = op(reinterpret_cast(src[mi].x), + reinterpret_cast(src[mi].y)); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_api.cpp b/candle-extensions/candle-flash-attn-v1/kernels/fmha_api.cpp new file mode 100644 index 00000000..8469042e --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_api.cpp @@ -0,0 +1,275 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#include +#include +#include + +#include "fmha.h" + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + + +void set_params_fprop(FMHA_fprop_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t h, + const size_t d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *o_tmp_d, + void *s_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + bool is_causal, + int num_splits) { + + Data_type acc_type = DATA_TYPE_FP32; + Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16; + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. 
+ params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.q_row_stride_in_elts = q.stride(0); + params.k_row_stride_in_elts = k.stride(0); + params.v_row_stride_in_elts = v.stride(0); + params.q_head_stride_in_elts = q.stride(1); + params.k_head_stride_in_elts = k.stride(1); + params.v_head_stride_in_elts = v.stride(1); + params.o_ptr = out.data_ptr(); + params.o_row_stride_in_elts = out.stride(0); + params.o_head_stride_in_elts = out.stride(1); + params.o_tmp_ptr = o_tmp_d; + params.o_tmp_row_stride_in_elts = h * d; + params.o_tmp_head_stride_in_elts = d; + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // S = softmax(P) + params.s_ptr = s_d; + params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.d = d; + + // Set the different scale values. + // const float scale_bmm1 = 1.f / sqrtf(d); + const float scale_bmm1 = softmax_scale; + + params.scale_bmm1f = scale_bmm1; + set_alpha(params.scale_bmm1, scale_bmm1, data_type); + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. 
+ // [Minor] We want to round down since when we do the comparison we use <= instead of < + params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f; + TORCH_CHECK(p_dropout < 1.f); + set_alpha(params.scale_dropout, params.rp_dropout, data_type); + + params.is_causal = is_causal; + params.num_splits = num_splits; +} + +void run_fmha_fwd(Launch_params &launch_params) { + if (launch_params.params.d <= 32) { + run_fmha_fwd_hdim32(launch_params); + } else if (launch_params.params.d <= 64) { + run_fmha_fwd_hdim64(launch_params); + } else if (launch_params.params.d <= 128) { + run_fmha_fwd_hdim128(launch_params); + } +} + +std::vector +mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + at::Tensor &out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool return_softmax, + const int num_splits, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x || is_sm75); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + bool is_dropout = p_dropout > 0.0; + Launch_params 
launch_params(dprops, stream, is_dropout, return_softmax); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16)); + TORCH_CHECK(k.dtype() == q_dtype); + TORCH_CHECK(v.dtype() == q_dtype); + TORCH_CHECK(out.dtype() == q_dtype); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32); + + TORCH_CHECK(q.is_cuda()); + TORCH_CHECK(k.is_cuda()); + TORCH_CHECK(v.is_cuda()); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(cu_seqlens_q.is_cuda()); + TORCH_CHECK(cu_seqlens_k.is_cuda()); + + TORCH_CHECK(q.stride(-1) == 1); + TORCH_CHECK(k.stride(-1) == 1); + TORCH_CHECK(v.stride(-1) == 1); + TORCH_CHECK(out.stride(-1) == 1); + TORCH_CHECK(cu_seqlens_q.is_contiguous()); + TORCH_CHECK(cu_seqlens_k.is_contiguous()); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + const int total_q = sizes[TOTAL_DIM]; + const int num_heads = sizes[H_DIM]; + const int head_size = sizes[D_DIM]; + const int total_k = k.size(TOTAL_DIM); + TORCH_CHECK(batch_size > 0); + TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128)); + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads, head_size); + CHECK_SHAPE(v, total_k, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + int blocksize_c = head_size > 64 ? 
128 : 256; + // Need to round max_seqlen_k to multiples of blocksize_c + int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; + if( max_seqlen_k_ <= 128 ) { + max_seqlen_k = 128; + } else if( max_seqlen_k_ <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > blocksize_c; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + // auto o = torch::empty({ total_q, num_heads, head_size }, opts); + + at::Tensor o_tmp; + if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); } + + auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + // auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits::infinity(), opts.dtype(at::kFloat)); + + at::Tensor s; + if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); } + + if( zero_tensors ) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) {s.zero_();} + } + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + set_params_fprop(launch_params.params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + q, k, v, out, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + loop ? o_tmp.data_ptr() : nullptr, + return_softmax ? s.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal, + num_splits); + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
+ int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + launch_params.params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if( is_dropout ) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); + } + + run_fmha_fwd(launch_params); + + std::vector result = {softmax_lse}; + result.push_back(rng_state); + if (return_softmax) {result.push_back(s);} + return result; +} diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h new file mode 100644 index 00000000..8fd2f2a0 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h @@ -0,0 +1,706 @@ +/*************************************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "fmha_kernel.h" +#include "fmha/kernel_traits.h" +#include "fmha/gemm.h" +#include "fmha/utils.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gemm_Q_K_base { + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + using Fragment_q = typename Smem_tile_q::Fragment; + using Fragment_k = typename Smem_tile_k::Fragment; + + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + + // The MMA tile for the 1st GEMM. 
+ using Mma_tile_p = fmha::Hmma_tile; + + static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; + + __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) + : smem_q(smem_ptr_q, tidx) + , smem_k(smem_ptr_k, tidx) { + + } + + __device__ inline void load_q() { + smem_q.load(frag_q[0], 0); + } + + __device__ inline void reload_q() { + smem_q.load(frag_q[0], 0); + } + + Fragment_q frag_q[2][Mma_tile_p::MMAS_M]; + Smem_tile_q smem_q; + Smem_tile_k smem_k; +}; + +template +struct Gemm_Q_K : public Gemm_Q_K_base { + + using Base = Gemm_Q_K_base; + using Smem_tile_o = typename Base::Smem_tile_o; + using Smem_tile_q = typename Base::Smem_tile_q; + using Smem_tile_k = typename Base::Smem_tile_k; + using Fragment_k = typename Base::Fragment_k; + using Mma_tile_p = typename Base::Mma_tile_p; + using elem_type = elem_type_; + + static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V; + // If V is stored in shared memory, we can't load K using the same shared memory. + static_assert(Kernel_traits::V_IN_REGS); + + static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE; + static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; + static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE); + + // Q | K / V + // | O | SOFTMAX + static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + + std::max((SHARE_SMEM_FOR_K_AND_V ? 
1 : 2) * Smem_tile_k::BYTES_PER_TILE, + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); + + __device__ inline Gemm_Q_K(char * smem_, const int tidx) + : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { + } + + __device__ inline void load_k(){ + #pragma unroll + for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { + Base::smem_k.load(frag_k[ki], ki); + } + } + + template + __device__ inline void operator()(Acc (&acc_p)[M][N]){ + // Do this part of P^T = (Q * K^T)^T. + #pragma unroll + for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + Base::smem_q.load(Base::frag_q[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); + } + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); + } + } + + __device__ inline void reload_k(){ + // Noop. + } + + Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N]; +}; + + +template +struct Gemm_Q_K : public Gemm_Q_K_base { + using Base = Gemm_Q_K_base; + using Smem_tile_o = typename Base::Smem_tile_o; + using Smem_tile_q = typename Base::Smem_tile_q; + using Smem_tile_k = typename Base::Smem_tile_k; + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + using Fragment_k = typename Base::Fragment_k; + using Mma_tile_p = typename Base::Mma_tile_p; + using elem_type = elem_type_; + Fragment_k frag_k[2][Mma_tile_p::MMAS_N]; + + static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V; + static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS; + static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V); + + static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 
0 : Smem_tile_k::BYTES_PER_TILE); + static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE); + static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE; + static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; + + // If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX + // If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX + static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE + + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; + + __device__ inline Gemm_Q_K(char * smem_, const int tidx) + : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { + } + + __device__ inline void load_k(){ + Base::smem_k.load(frag_k[0], 0); + } + + template + __device__ inline void operator()(Acc (&acc_p)[M][N]){ + // Do this part of P^T = (Q * K^T)^T. + #pragma unroll + for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + Base::smem_q.load(Base::frag_q[ki & 1], ki); + Base::smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + // Do the final stage of math. 
+ { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + } + + __device__ inline void reload_k(){ + Base::smem_k.load(frag_k[0], 0); + } +}; + +template +constexpr size_t get_dynamic_smem_size(){ + return Gemm_Q_K::SMEM_BYTES; +} + +template +inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif + + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + using Gmem_tile_o_tmp = fmha::Gmem_tile_o; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; + + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // Shared memory. 
+ extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + // How many steps to jump per iteration, which is the same as params.num_splits. + const int step_stride = gridDim.z; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + // if( binfo.stop_early() ) return; + if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; + + Gemm1 gemm_q_k(smem_, tidx); + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, + params.d, binfo, tidx, true); + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, + params.d, binfo, tidx); + Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, + params.o_tmp_head_stride_in_elts, params.d, binfo, tidx); + // Allocate the global memory tile loader for S. + Gmem_tile_s gmem_s(params, binfo, tidx); + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); + + // Wind gmem tiles to the correct position. + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0; + // We want begin to be a multiple of gridDim.z + // This is because the row indices processed by each threadblock must align between the + // loop steps, otherwise we have a dependency between the blocks. + // For example, threadblock with blockIdx.z == 1 must process row indices that are + // k * gridDim.z + 1 for integer k. + const int begin_mod_z = begin % gridDim.z; + begin = begin_mod_z <= blockIdx.z ? 
begin - begin_mod_z : begin + gridDim.z - begin_mod_z; + // Otherwise we'd be reading out-of-bound memory before the loop + if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return; + const int steps_og = steps; + steps -= begin; + gmem_q.move(begin + blockIdx.z); + gmem_o.move(begin + blockIdx.z); + gmem_o_tmp.move(begin + blockIdx.z); + if (Return_softmax) { + gmem_s.move(begin + blockIdx.z); + } + gmem_softmax_lse.move(begin + blockIdx.z); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("begin = %d, steps = %d\n", begin, steps); + // } + + fmha::Mask mask(binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, + params.d, binfo, tidx, false); + // Allocate the global memory tile loader for V. + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, + params.d, binfo, tidx, false); + // The base pointer of smem_v; + char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; + + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! + Smem_tile_v smem_v(smem_v_, tidx); + + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); + + if (!Is_first) { + gmem_k.move(loop_step_idx); + gmem_v.move(loop_step_idx); + if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } + } + + // Trigger the loads for K. + gmem_k.load(); + // Trigger the loads for Q. + gmem_q.load(); + // Trigger the loads for V. + gmem_v.load(); + + if (!Is_first) { __syncthreads(); } + + float p_prev_lse[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + gmem_softmax_lse.load(reinterpret_cast(p_prev_lse)); + } + + // Commit the data for Q and V to shared memory. 
+ gmem_q.commit(gemm_q_k.smem_q); + gmem_v.commit(smem_v); + + // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); + // #pragma unroll + // for(int it=0;it < Gmem_tile_k::LDGS;it++){ + // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); + // } + + // Commit the data for K to shared memory. + if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + gmem_k.commit(gemm_q_k.smem_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire kernel. + typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + smem_v.load(frag_v[ki], ki); + } + + // Commit the data for V to shared memory if it has not been done already. + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + // Make sure we are done loading the fragments for K. + __syncthreads(); + + // Commit the data to shared memory for V. + gmem_k.commit(gemm_q_k.smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Load the fragments for K. + gemm_q_k.load_k(); + + // Create the object to do the softmax. + Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); + + Smem_softmax_sum smem_softmax_lse(reinterpret_cast(&smem_[Gemm1::SMEM_BYTES]), tidx); + + // Load over the entire sequence length. + for (int l = blockIdx.z; l < steps; l += step_stride) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) { + // printf("l = %d\n", l); + // } + if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break; + + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator::apply(acc_p); + + // Do this part of P = Q * K^T. 
+ gemm_q_k(acc_p); + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); + // } + + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + if (!Is_first) { gmem_o_tmp.load(out, 0); } + + // Trigger the load for the next Q values. + if (l + step_stride < steps) { + gemm_q_k.smem_q.move_to_next_write_buffer(); + gmem_q.move(step_stride); + gmem_q.load(); + } + + // Load the mask for that iteration. + mask.load(begin + l); + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + + // Apply the mask. + softmax.apply_mask(mask); + + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) { + // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction + __syncthreads(); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) { + // printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]); + // } + // } + // Compute the max. + float p_max[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + smem_softmax_lse.store_pair(p_prev_lse); + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; } + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; } + } + + // Trigger the load for the next LSE values. + if (l + step_stride < steps) { + if (!Is_first) { + gmem_softmax_lse.load_next(reinterpret_cast(p_prev_lse), + step_stride); + } + } + + softmax.template reduce_max(p_max); + + // if ((threadIdx.x == 0) && (l == 38)) { + // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? 
-10000.f : p_prev_lse[1]); + // } + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); + // } + // } + + // Compute the exponential value. + // softmax.apply_exp(p_max); + softmax.scale_apply_exp(p_max, params.scale_bmm1f); + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); + // } + // } + + // Compute the sum. + float p_sum[Mma_tile_p::MMAS_M * 2]; + // if (!Is_first) { + // int warp = tidx / Cta_tile_p::THREADS_PER_WARP; + // int lane = tidx % Cta_tile_p::THREADS_PER_WARP; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + // p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0; + // } + // } + // softmax.reduce_sum(p_sum); + softmax.reduce_sum_before_sync_(p_sum); + // softmax.template reduce_sum_before_sync_(p_sum); + + // float p_sum_log[Mma_tile_p::MMAS_M * 2]; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { + // float sum = p_sum[mi]; + // // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum); + // constexpr float kLog2e = M_LOG2E; + // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum); + // } + // // gmem_softmax_lse.store(reinterpret_cast(p_sum)); + // gmem_softmax_lse.store(reinterpret_cast(p_sum_log)); + // gmem_softmax_lse.move(); + + // // Finalize softmax on the accumulators of P^T. 
+ // softmax.scale(p_sum); + + constexpr bool encode_dropout_in_sign_bit = Return_softmax; + if (Is_dropout) { + // softmax.template apply_dropout(ph, params.p_dropout_in_uint); + // softmax.template apply_dropout(ph, ph1, params.p_dropout_in_uint); + // softmax.template apply_dropout_16bits(ph, ph1, params.p_dropout_in_uint16_t); + unsigned int warp_idx = threadIdx.x / 32; + // TODO: this should change after we rearrange the warps (e.g. cutlass branch) + unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx; + // We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded + // differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k + // to multiples of 256 while bwd rounds seqlen_k to multiples of 128. + unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx; + softmax.template apply_dropout_16bits(ph, params.p_dropout_in_uint16_t, philox_subsequence); + } + + using Frag_p = fmha::Fragment_a; + Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); + static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); + softmax.template pack(frag_p); + if (Return_softmax) { + gmem_s.store(frag_p, mask); + gmem_s.move(step_stride); + } + + // Commit the values for Q into shared memory. + if (l + step_stride < steps) { + gmem_q.commit(gemm_q_k.smem_q); + } + + if (Is_dropout && encode_dropout_in_sign_bit) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { + #pragma unroll + for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { + frag_p[ki][mi].template hrelu_(); + } + } + } + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator::apply(acc_o); + + // Do this part of O = P^T * V^T. 
+ #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); + // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki])); + // float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki])); + // printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0)); + // } + } + + // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0)); + // } + + // The mapping from tidx to rows changes between the softmax and the + // O-reduction. So we recalculate the max. + float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + int rows[Gmem_tile_o::STGS_PER_LOOP]; + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; + } + softmax.reduce_max_after_sync_(p_max_o, rows); + static_assert(Mma_tile_o::MMAS_M == 1); + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_max_o[jj][0] *= params.scale_bmm1f; + } + float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; + if (!Is_first) { + smem_softmax_lse.load(p_prev_scale_o, rows); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); + // } + // } + + static_assert(Gmem_tile_o::LOOPS == 1); + + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, 0); + + // Make sure the data is in shared memory. 
+ __syncthreads(); + + static_assert(Mma_tile_o::MMAS_M == 1); + float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + softmax.reduce_sum_after_sync_(p_sum_o, rows); + if (!Is_first) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); + p_sum_o[jj][0] += p_prev_scale_o[jj]; + } + } + + float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + #pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); + // if (sum == 0.f || sum != sum) { + // printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]); + // } + // if (Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); + // } + // } + if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) { + gmem_softmax_lse.store_row( + reinterpret_cast(p_sum_log[jj]), rows[jj]); + } + } + gmem_softmax_lse.move(step_stride); + + // Load from shared memory. + if (!Is_first) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]); + } + } + smem_o.template load(out); + + const bool is_final_write = + Is_last + || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) + || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); + #pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + float inv_sum = (sum == 0.f || sum != sum) ? 
1.f : 1.f / sum; + if (Is_dropout && is_final_write) { + inv_sum *= params.rp_dropout; + } + out[jj] = fmha::fmul4(out[jj], inv_sum); + } + + // if (Is_dropout && Is_last) { + // for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + // out[jj] = fmha::fmul4(out[jj], params.rp_dropout); + // } + // } + + // Output the values. + if (is_final_write) { + gmem_o.template store(out, 0); + gmem_o.move(step_stride); + } else { + gmem_o_tmp.store(out, 0); + } + + // Move to the next part of the output. + if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); } + gemm_q_k.reload_k(); + + // Make sure we are reading from the correct buffer. + gemm_q_k.smem_q.move_to_next_read_buffer(); + // Trigger the load from shared memory for the next series of Q values. + if (l + step_stride < steps) { + gemm_q_k.reload_q(); + } + } // Outer loop over the sequence length. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void device_1xN_loop(const Params ¶ms) { + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The block index. + const int bidx = gridDim.x * bidh + bidb; + // The thread index. + const int tidx = threadIdx.x; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern. 
+// auto seeds = at::cuda::philox::unpack(params.philox_args); +// if (bidx == 0 && tidx == 0) { +// params.rng_state[0] = std::get<0>(seeds); +// params.rng_state[1] = std::get<1>(seeds); +// } + Philox ph(0, 0, 0 + (bidb * params.h + bidh) * 32 + tidx % 32); + constexpr int M = Kernel_traits::Cta_tile_p::M; + const int STEPS = (params.seqlen_q + M - 1) / M; + + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + if (params.seqlen_k == blocksize_c) { + fmha::device_1xN_(params, bidb, bidh, STEPS, ph, 0); + } else { + const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; + fmha::device_1xN_(params, bidb, bidh, STEPS, ph, 0); + for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { + fmha::device_1xN_(params, bidb, bidh, STEPS, ph, loop_step_idx); + } + fmha::device_1xN_(params, bidb, bidh, STEPS, ph, max_loop_steps - 1); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim128.cu b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim128.cu new file mode 100644 index 00000000..5ce9d147 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim128.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_hdim128(Launch_params &launch_params) { + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fwd_loop(launch_params); + })); +} diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim32.cu b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim32.cu new file mode 100644 index 00000000..ed15e131 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim32.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_hdim32(Launch_params &launch_params) { + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if (launch_params.params.seqlen_k == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fwd_loop(launch_params); + } else if (launch_params.params.seqlen_k >= 256) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fwd_loop(launch_params); + } + })); +} diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim64.cu b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim64.cu new file mode 100644 index 00000000..134efa63 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_hdim64.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_hdim64(Launch_params &launch_params) { + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if (launch_params.params.seqlen_k == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fwd_loop(launch_params); + } else if (launch_params.params.seqlen_k >= 256) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fwd_loop(launch_params); + } + })); +} diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_launch_template.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_launch_template.h new file mode 100644 index 00000000..2da27e92 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_fwd_launch_template.h @@ -0,0 +1,91 @@ +// Copyright (c) 2022, Tri Dao. + +#pragma once + +#include + +#include +#include + +#include "static_switch.h" +#include "fmha.h" +#include "fmha_fprop_kernel_1xN.h" + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 95% +// of the best efficiency. +// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error. 
+inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) { + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm); + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (efficiency[num_splits - 1] > 0.95 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +template +__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) { + fmha::device_1xN_loop(params); +} + +template +void run_fmha_fwd_loop(Launch_params &launch_params) { + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; + + constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + // Don't need smem_size_softmax_lse if we're not looping + const int smem_size = fmha::get_dynamic_smem_size() + + (loop_steps > 1 ? smem_size_softmax_lse : 0); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fwd_loop_kernel + : &fmha_fwd_loop_kernel) + : (launch_params.return_softmax + ? 
&fmha_fwd_loop_kernel + : &fmha_fwd_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // Automatically set num_splits to maximize occupancy + if (launch_params.params.num_splits <= 0) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size); +// auto dprops = at::cuda::getCurrentDeviceProperties(); + // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount); + constexpr int M = Kernel_traits::Cta_tile_p::M; + launch_params.params.num_splits = num_splits_heuristic_fwd( + launch_params.params.b * launch_params.params.h, launch_params.multi_processor_count, + ctas_per_sm, + /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M)) + ); + } + // printf("smem_size = %d\n", smem_size); + dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + })); +} diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_kernel.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha_kernel.h new file mode 100644 index 00000000..1e164bad --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_kernel.h @@ -0,0 +1,78 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "philox.cuh" + +#include "fmha.h" +#include "fmha/utils.h" +#include "fmha/smem_tile.h" +#include "fmha/gmem_tile.h" +#include "fmha/mask.h" +#include "fmha/softmax.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfoPadded { + + template + __device__ BlockInfoPadded(const Params ¶ms, + const int bidb, + const int bidh, + const int tidx) + : bidb(bidb), bidh(bidh), h(params.h) { + + // The block index. 
+ sum_s_k = params.cu_seqlens_k[bidb]; + actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k; + sum_s_q = params.cu_seqlens_q[bidb]; + actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q; + + tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; + } + + __device__ bool stop_early(const int start_col = 0) const { + return actual_seqlen_k <= start_col; + } + + int actual_seqlen_q; + int actual_seqlen_k; + int sum_s_q; + int sum_s_k; + int bidh; + int bidb; + int tidx_global; + int h; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/candle-extensions/candle-flash-attn-v1/kernels/fmha_utils.h b/candle-extensions/candle-flash-attn-v1/kernels/fmha_utils.h new file mode 100644 index 00000000..057eed28 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/fmha_utils.h @@ -0,0 +1,99 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define FMHA_CHECK_CUDA( call ) \ + do { \ + cudaError_t status_ = call; \ + if( status_ != cudaSuccess ) { \ + fprintf( stderr, \ + "CUDA error (%s:%d): %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString( status_ ) ); \ + exit( 1 ); \ + } \ + } while( 0 ) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) { + if( dtype == DATA_TYPE_FP16 ) { + half x = __float2half_rn( norm ); + uint16_t h = reinterpret_cast( x ); + ushort2 h2 = { h, h }; + alpha = reinterpret_cast( h2 ); + } else if( dtype == DATA_TYPE_BF16 ) { + __nv_bfloat16 x = __float2bfloat16( norm ); + uint16_t h = reinterpret_cast( x 
); + ushort2 h2 = { h, h }; + alpha = reinterpret_cast( h2 ); + } else if( dtype == DATA_TYPE_FP32 ) { + alpha = reinterpret_cast( norm ); + } else if( dtype == DATA_TYPE_INT32 ) { + int32_t inorm = static_cast( norm ); + alpha = reinterpret_cast( inorm ); + } else { + assert( false ); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) { + switch( dtype ) { + case DATA_TYPE_FP32: + return n * 4; + case DATA_TYPE_FP16: + return n * 2; + case DATA_TYPE_BF16: + return n * 2; + case DATA_TYPE_INT32: + return n * 4; + case DATA_TYPE_INT8: + return n; + default: + assert( false ); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-extensions/candle-flash-attn-v1/kernels/philox.cuh b/candle-extensions/candle-flash-attn-v1/kernels/philox.cuh new file mode 100644 index 00000000..bab5c39f --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/philox.cuh @@ -0,0 +1,157 @@ +// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +#pragma once +// Philox CUDA. 
+ +namespace { + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) + : key(reinterpret_cast(seed)) { + //key.x = (unsigned int)seed; + //key.y = (unsigned int)(seed >> 32); + //counter = make_uint4(0, 0, 0, 0); + //counter.z = (unsigned int)(subsequence); + //counter.w = (unsigned int)(subsequence >> 32); + //STATE = 0; + //incr_n(offset / 4); + + ull2 * tmp = reinterpret_cast(&counter); + tmp->x = offset / 4; + tmp->y = subsequence; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); + // } + } + + __device__ inline uint4 operator()() { + uint4 counter_ = counter; + uint2 key_ = key; + // 7-round philox + #pragma unroll + for (int i = 0; i < 6; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); + key_.y += (kPhilox10B); + } + uint4 output = single_round(counter_, key_); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // } + incr(); + return output; + } + + __device__ inline uint4 operator()(const unsigned long long subsequence) { + uint4 counter_ = counter; + ull2 * tmp = reinterpret_cast(&counter_); + tmp->y = subsequence; + // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w); + // } + uint2 key_ = key; + // 7-round philox + #pragma unroll + for (int i = 0; i < 6; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); + key_.y += (kPhilox10B); + } + uint4 output = single_round(counter_, key_); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // 
printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // } + return output; + } + +private: + struct ull2 { + uint64_t x; + uint64_t y; + }; + uint4 counter; + const uint2 key; + + // __device__ inline void incr_n(unsigned long long n) { + // unsigned int nlo = (unsigned int)(n); + // unsigned int nhi = (unsigned int)(n >> 32); + // counter.x += nlo; + // if (counter.x < nlo) + // nhi++; + // counter.y += nhi; + // if (nhi <= counter.y) + // return; + // if (++counter.z) + // return; + // ++counter.w; + // } + + __device__ uint4 incr(uint4 ctr) { + uint4 res; + asm ("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), + "n"(1), "n"(0), "n"(0), "n"(0)); + return res; + } + + __device__ inline void incr() { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + counter = incr(counter); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + } + + // __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + // unsigned int *result_high) { + // *result_high = __umulhi(a, b); + // return a * b; + // } + + __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; + } + + __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) { + //unsigned int hi0; + //unsigned int hi1; + //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + //unsigned int lo1 = 
mulhilo32(kPhiloxSB, ctr.z, &hi1); + //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +// Inverse of 2^32. +constexpr float M_RAN_INVM32 = 2.3283064e-10f; +__device__ __inline__ float4 uniform4(const uint4 x) { + return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, + x.w * M_RAN_INVM32); +} + +} // namespace diff --git a/candle-extensions/candle-flash-attn-v1/kernels/static_switch.h b/candle-extensions/candle-flash-attn-v1/kernels/static_switch.h new file mode 100644 index 00000000..53bcf35d --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/kernels/static_switch.h @@ -0,0 +1,40 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8 + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, ([&] { +/// some_function(...); +/// })); +/// ``` +/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro. 
+#define BOOL_SWITCH(COND, CONST_NAME, F) \ + { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + F(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + F(); \ + } \ + } + +// modified from BOOL_SWITCH +// because MSVC cannot handle std::conditional with constexpr variable +#define FP16_SWITCH(COND, F) \ + { \ + if (COND) { \ + using elem_type = __nv_bfloat16; \ + F(); \ + } else { \ + using elem_type = __half; \ + F(); \ + } \ + } diff --git a/candle-extensions/candle-flash-attn-v1/src/ffi.rs b/candle-extensions/candle-flash-attn-v1/src/ffi.rs new file mode 100644 index 00000000..d5474f2a --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/src/ffi.rs @@ -0,0 +1,41 @@ +use core::ffi::{c_int, c_void}; + +extern "C" { + pub(crate) fn run_mha( + q_ptr: *const c_void, + k_ptr: *const c_void, + v_ptr: *const c_void, + o_ptr: *const c_void, + o_tmp_ptr: *const c_void, + softmax_lse_ptr: *const c_void, + cu_seqlens_q_ptr: *const i32, + cu_seqlens_k_ptr: *const i32, + + q_row_stride: u32, + k_row_stride: u32, + v_row_stride: u32, + o_row_stride: u32, + o_tmp_row_stride: u32, + + q_head_stride: u32, + k_head_stride: u32, + v_head_stride: u32, + o_head_stride: u32, + o_tmp_head_stride: u32, + + b: u32, + h: u32, + d: u32, + softmax_scale: f32, + + seqlen_q: u32, + seqlen_k: u32, + + is_causal: c_int, + is_bf16: c_int, + + multi_processor_count: i32, + num_splits: i32, + ); + +} diff --git a/candle-extensions/candle-flash-attn-v1/src/lib.rs b/candle-extensions/candle-flash-attn-v1/src/lib.rs new file mode 100644 index 00000000..bc1d6866 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/src/lib.rs @@ -0,0 +1,272 @@ +mod ffi; + +use candle::backend::BackendStorage; +use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT; +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use half::{bf16, f16}; +use 
std::ptr; + +struct FlashAttnVarLen { + softmax_scale: f32, + causal: bool, + max_seqlen_q: usize, + max_seqlen_k: usize, + seqlens_q: Tensor, + seqlens_k: Tensor, +} + +fn round_multiple(x: usize, m: usize) -> usize { + (x + m - 1) / m * m +} + +impl FlashAttnVarLen { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + is_bf16: bool, + ) -> Result<(candle::CudaStorage, Shape)> { + // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327 + let dev = q.device(); + let out_shape = q_l.shape().clone(); + let out_l = Layout::contiguous(&out_shape); + + let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout(); + let seqlens_q = match &*seqlens_q { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, // Should be i32! + _ => candle::bail!("seqlens_q must be a cuda tensor"), + }; + let seqlens_q = match seqlens_q_layout.contiguous_offsets() { + Some((o1, o2)) => seqlens_q.slice(o1..o2), + None => candle::bail!("seqlens_q has to be contiguous"), + }; + + let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout(); + let seqlens_k = match &*seqlens_k { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, // Should be i32! 
+ _ => candle::bail!("seqlens_k must be a cuda tensor"), + }; + let seqlens_k = match seqlens_k_layout.contiguous_offsets() { + Some((o1, o2)) => seqlens_k.slice(o1..o2), + None => candle::bail!("seqlens_k has to be contiguous"), + }; + + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; + let q = q.slice(q_l.start_offset()..); + let k = k.slice(k_l.start_offset()..); + let v = v.slice(v_l.start_offset()..); + + let q_stride = q_l.stride(); + let k_stride = k_l.stride(); + let v_stride = v_l.stride(); + let o_stride = out_l.stride(); + + let q_rank = q_stride.len(); + let k_rank = k_stride.len(); + let v_rank = v_stride.len(); + let o_rank = o_stride.len(); + + if q_rank != 3 || k_rank != 3 || v_rank != 3 { + candle::bail!( + "flash-attn-varlen expects input tensors of rank 3 (q: {q_rank}, k: {k_rank}, v: {v_rank}" + ) + } + if q_stride[q_rank - 1] != 1 { + candle::bail!("the last dim of q must be contiguous {q_stride:?}") + } + if k_stride[k_rank - 1] != 1 { + candle::bail!("the last dim of k must be contiguous {k_stride:?}") + } + if v_stride[v_rank - 1] != 1 { + candle::bail!("the last dim of v must be contiguous {v_stride:?}") + } + + let (total_q, num_heads, head_size) = q_l.shape().dims3()?; + let (total_k, num_heads_k, _head_size) = k_l.shape().dims3()?; + let expected_kv = (total_k, num_heads_k, head_size); + if expected_kv != k_l.shape().dims3()? { + candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape()) + } + if expected_kv != v_l.shape().dims3()? { + candle::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape()) + } + if head_size > 256 { + candle::bail!("only supports head dimension at most 256 (got {head_size})") + } + if head_size % 8 != 0 { + // TODO: Handle head sizes that are not a multiple of 8 via some padding. 
+ candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size})") + } + if num_heads % num_heads_k != 0 { + candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") + } + + let nseqlens_q = seqlens_q_layout.shape().dims1()?; + if nseqlens_q < 2 { + candle::bail!("seqlens_q should have a len >= 2 {nseqlens_q}") + } + let nseqlens_k = seqlens_k_layout.shape().dims1()?; + if nseqlens_k != nseqlens_q { + candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") + } + let batch_size = nseqlens_q - 1; + + let elem_count = out_shape.elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let softmax_lse = dev + .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) + .w()?; + + let blocksize_c = if head_size > 64 { 128 } else { 256 }; + let max_seqlen_k_rounded = round_multiple(self.max_seqlen_k, blocksize_c); + let max_seqlen_q_rounded = round_multiple(self.max_seqlen_q, 16); + + let dst_temp = if max_seqlen_k_rounded > blocksize_c { + Some(unsafe { dev.alloc::(total_q * num_heads * head_size) }.w()?) 
+ } else { + None + }; + + let causal = if self.causal { 1 } else { 0 }; + let is_bf16 = if is_bf16 { 1 } else { 0 }; + + let multi_processor_count = dev + .attribute(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT) + .w()?; + + unsafe { + let q_ptr = *q.device_ptr() as *const core::ffi::c_void; + let k_ptr = *k.device_ptr() as *const core::ffi::c_void; + let v_ptr = *v.device_ptr() as *const core::ffi::c_void; + let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; + let dst_tmp_ptr = if let Some(slice) = &dst_temp { + *slice.device_ptr() as *const core::ffi::c_void + } else { + ptr::null() + }; + let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; + let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int; + let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int; + ffi::run_mha( + q_ptr, + k_ptr, + v_ptr, + dst_ptr, + dst_tmp_ptr, + softmax_lse_ptr, + /* cu_seqlens_q_ptr */ seqlens_q_ptr, + /* cu_seqlens_k_ptr */ seqlens_k_ptr, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* o_tmp_row_stride */ (num_heads * head_size) as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* o_tmp_head_stride */ head_size as u32, + /* b */ batch_size as u32, + /* h */ num_heads as u32, + /* d */ head_size as u32, + /* softmax_scale*/ self.softmax_scale, + /* seqlen_q */ max_seqlen_q_rounded as u32, + /* seqlen_k */ max_seqlen_k_rounded as u32, + /* is_causal */ causal, + /* is_bf16 */ is_bf16, + /* multi_processor_count */ multi_processor_count, + /* num_splits */ 0, + ) + } + + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +} + +impl candle::CustomOp3 for 
FlashAttnVarLen { + fn name(&self) -> &'static str { + "flash-attn-varlen" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match q.dtype() { + candle::DType::F16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, false), + candle::DType::BF16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, true), + dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"), + } + } +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. 
+pub fn flash_attn_varlen( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + causal, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + }; + q.apply_op3(k, v, op) +} diff --git a/candle-extensions/candle-flash-attn-v1/tests/flash_attn_tests.rs b/candle-extensions/candle-flash-attn-v1/tests/flash_attn_tests.rs new file mode 100644 index 00000000..4b266cb3 --- /dev/null +++ b/candle-extensions/candle-flash-attn-v1/tests/flash_attn_tests.rs @@ -0,0 +1,61 @@ +use anyhow::Result; +use candle::{DType, Device, Tensor}; + +fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) +} + +#[test] +fn flash_attn_varlen() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?; + let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + candle_flash_attn::flash_attn_varlen( + &q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false, + )? + .transpose(0, 1)? 
+ }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684] + ], + [ + [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955], + [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023] + ] + ] + ); + Ok(()) +} diff --git a/candle-extensions/candle-layer-norm/.gitignore b/candle-extensions/candle-layer-norm/.gitignore new file mode 100644 index 00000000..fbc9a58c --- /dev/null +++ b/candle-extensions/candle-layer-norm/.gitignore @@ -0,0 +1,3 @@ +.idea +target +Cargo.lock diff --git a/candle-extensions/candle-layer-norm/Cargo.toml b/candle-extensions/candle-layer-norm/Cargo.toml new file mode 100644 index 00000000..8ab1b6d3 --- /dev/null +++ b/candle-extensions/candle-layer-norm/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "candle-layer-norm" +version = "0.0.1" +edition = "2021" + +description = "Layer Norm layer for the candle ML framework." + +[dependencies] +candle = { workspace = true, features = ["cuda"] } +half = { workspace = true } + +[build-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +num_cpus = "1.15.0" +rayon = "1.7.0" diff --git a/candle-extensions/candle-layer-norm/LICENSE-APACHE b/candle-extensions/candle-layer-norm/LICENSE-APACHE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/candle-extensions/candle-layer-norm/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/candle-extensions/candle-layer-norm/LICENSE-MIT b/candle-extensions/candle-layer-norm/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/candle-extensions/candle-layer-norm/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/candle-extensions/candle-layer-norm/README.md b/candle-extensions/candle-layer-norm/README.md new file mode 100644 index 00000000..e5b3f4ac --- /dev/null +++ b/candle-extensions/candle-layer-norm/README.md @@ -0,0 +1,14 @@ +# Candle Cuda Layer Norm + +Layer Norm fused operation for the Candle ML framework. + +This Layer was adapted from https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm. + +It implements fused dropout + residual + LayerNorm, building on Apex's FastLayerNorm. + +Major changes: + +- Add residual. +- Make it work for both pre-norm and post-norm architecture. +- Support more hidden dimensions (all dimensions divisible by 8, up to 8192). +- Implement RMSNorm as an option. diff --git a/candle-extensions/candle-layer-norm/build.rs b/candle-extensions/candle-layer-norm/build.rs new file mode 100644 index 00000000..d7a78038 --- /dev/null +++ b/candle-extensions/candle-layer-norm/build.rs @@ -0,0 +1,256 @@ +// Build script to run nvcc and generate the C glue code for launching the layer-norm kernel. +// The cuda build time is very long so one can set the CANDLE_LAYER_NORM_BUILD_DIR environment +// variable in order to cache the compiled artifacts and avoid recompiling too often. 
+use anyhow::{Context, Result}; +use rayon::prelude::*; +use std::path::PathBuf; +use std::str::FromStr; + +const KERNEL_FILES: [&str; 1] = ["ln_api.cu"]; + +fn main() -> Result<()> { + let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( + |_| num_cpus::get_physical(), + |s| usize::from_str(&s).unwrap(), + ); + + rayon::ThreadPoolBuilder::new() + .num_threads(num_cpus) + .build_global() + .unwrap(); + + println!("cargo:rerun-if-changed=build.rs"); + for kernel_file in KERNEL_FILES.iter() { + println!("cargo:rerun-if-changed=kernels/{kernel_file}"); + } + println!("cargo:rerun-if-changed=kernels/**.cu"); + println!("cargo:rerun-if-changed=kernels/ln_fwd_kernels.cuh"); + println!("cargo:rerun-if-changed=kernels/ln_kernel_traits.h"); + println!("cargo:rerun-if-changed=kernels/ln_utils.cuh"); + println!("cargo:rerun-if-changed=kernels/static_switch.h"); + let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); + let build_dir = match std::env::var("CANDLE_LAYER_NORM_BUILD_DIR") { + Err(_) => + { + #[allow(clippy::redundant_clone)] + out_dir.clone() + } + Ok(build_dir) => { + let path = PathBuf::from(build_dir); + path.canonicalize().expect(&format!( + "Directory doesn't exists: {} (the current directory is {})", + &path.display(), + std::env::current_dir()?.display() + )) + } + }; + set_cuda_include_dir()?; + + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); + + let compute_cap = compute_cap()?; + + let out_file = build_dir.join("liblayernorm.a"); + + let kernel_dir = PathBuf::from("kernels"); + let cu_files: Vec<_> = KERNEL_FILES + .iter() + .map(|f| { + let mut obj_file = out_dir.join(f); + obj_file.set_extension("o"); + (kernel_dir.join(f), obj_file) + }) + .collect(); + + let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified()); + let should_compile = if out_file.exists() { + kernel_dir + .read_dir() + .expect("kernels folder should exist") + 
.any(|entry| { + if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) { + let in_modified = entry.metadata().unwrap().modified().unwrap(); + in_modified.duration_since(*out_modified).is_ok() + } else { + true + } + }) + } else { + true + }; + if should_compile { + cu_files + .par_iter() + .map(|(cu_file, obj_file)| { + let mut command = std::process::Command::new("nvcc"); + command + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_BFLOAT16_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("-U__CUDA_NO_BFLOAT162_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT162_CONVERSIONS__") + .arg(format!("--gpu-architecture=sm_{compute_cap}")) + .arg("-c") + .args(["-o", obj_file.to_str().unwrap()]) + .args(["--default-stream", "per-thread"]) + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose"); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } + command.arg(cu_file); + let output = command + .spawn() + .context("failed spawning nvcc")? + .wait_with_output()?; + if !output.status.success() { + anyhow::bail!( + "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + &command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + Ok(()) + }) + .collect::>()?; + let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::>(); + let mut command = std::process::Command::new("nvcc"); + command + .arg("--lib") + .args(["-o", out_file.to_str().unwrap()]) + .args(obj_files); + let output = command + .spawn() + .context("failed spawning nvcc")? 
+ .wait_with_output()?; + if !output.status.success() { + anyhow::bail!( + "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + &command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + } + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=layernorm"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=stdc++"); + + Ok(()) +} + +fn set_cuda_include_dir() -> Result<()> { + // NOTE: copied from cudarc build.rs. + let env_vars = [ + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDNN_LIB", + ]; + let env_vars = env_vars + .into_iter() + .map(std::env::var) + .filter_map(Result::ok) + .map(Into::::into); + + let roots = [ + "/usr", + "/usr/local/cuda", + "/opt/cuda", + "/usr/lib/cuda", + "C:/Program Files/NVIDIA GPU Computing Toolkit", + "C:/CUDA", + ]; + let roots = roots.into_iter().map(Into::::into); + let root = env_vars + .chain(roots) + .find(|path| path.join("include").join("cuda.h").is_file()) + .context("cannot find include/cuda.h")?; + println!( + "cargo:rustc-env=CUDA_INCLUDE_DIR={}", + root.join("include").display() + ); + Ok(()) +} + +#[allow(unused)] +fn compute_cap() -> Result { + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + + // Try to parse compute caps from env + let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); + compute_cap_str + .parse::() + .context("Could not parse code")? + } else { + // Use nvidia-smi to get the current compute cap + let out = std::process::Command::new("nvidia-smi") + .arg("--query-gpu=compute_cap") + .arg("--format=csv") + .output() + .context("`nvidia-smi` failed. 
Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; + let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; + let mut lines = out.lines(); + assert_eq!( + lines.next().context("missing line in stdout")?, + "compute_cap" + ); + let cap = lines + .next() + .context("missing line in stdout")? + .replace('.', ""); + let cap = cap + .parse::() + .with_context(|| format!("cannot parse as int {cap}"))?; + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); + cap + }; + + // Grab available GPU codes from nvcc and select the highest one + let (supported_nvcc_codes, max_nvcc_code) = { + let out = std::process::Command::new("nvcc") + .arg("--list-gpu-code") + .output() + .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); + let out = std::str::from_utf8(&out.stdout).unwrap(); + + let out = out.lines().collect::>(); + let mut codes = Vec::with_capacity(out.len()); + for code in out { + let code = code.split('_').collect::>(); + if !code.is_empty() && code.contains(&"sm") { + if let Ok(num) = code[1].parse::() { + codes.push(num); + } + } + } + codes.sort(); + let max_nvcc_code = *codes.last().unwrap(); + (codes, max_nvcc_code) + }; + + // Check that nvcc supports the asked compute cap + if !supported_nvcc_codes.contains(&compute_cap) { + anyhow::bail!( + "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." 
+ ); + } + if compute_cap > max_nvcc_code { + anyhow::bail!( + "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" + ); + } + + Ok(compute_cap) +} diff --git a/candle-extensions/candle-layer-norm/kernels/ln.h b/candle-extensions/candle-layer-norm/kernels/ln.h new file mode 100644 index 00000000..aa49093d --- /dev/null +++ b/candle-extensions/candle-layer-norm/kernels/ln.h @@ -0,0 +1,204 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + + +//#ifdef OLD_GENERATOR_PATH +//#include +//#else +//#include +//#endif + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LaunchParams{ + + size_t elts_per_thread; + size_t workspace_bytes; + size_t barrier_size; + + int multi_processor_count; + + cudaStream_t stream; + + Params params; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ParamsBase { + ParamsBase() + : ctas_per_col(0) + , rows(0) + , cols(0) + , x(nullptr) + , mu(nullptr) + , rs(nullptr) + , gamma(nullptr) + , gamma1(nullptr) + , rowscale(nullptr) + , colscale(nullptr) + , dropout_keep_p(1.f) + , dropout_scale(1.f) + , is_rms_norm(false) + , workspace(nullptr) + , barrier(nullptr) + { + } + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void *x0; + void *x1; + void *residual; + void *x; + void *dmask; + void *dmask1; + void *mu; + void *rs; + void *gamma; + void *gamma1; + void *rowscale; + void *colscale; + void *x0_subset; + void *z_subset; + + float inverse_cols; + + float dropout_keep_p; + float dropout_scale; + float rowscale_const; + + bool is_rms_norm; + + // Multi-CTA workspace in gmem. 
+ void *workspace; + + // Multi-CTA sync barriers in gmem. + int *barrier; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct FwdParams : public ParamsBase { + FwdParams() + : ParamsBase() + , z(nullptr) + , z1(nullptr) + , beta(nullptr) + , beta1(nullptr) + , epsilon(0.f) + { + } + + // Output of LN FWD. + void *z; + void *z1; + void *beta; + void *beta1; + float epsilon; + + // Random state. + // at::PhiloxCudaState philox_args; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using FwdFunction = std::function&, const bool)>; +using FunctionKey = uint64_t; +using FwdRegistry = std::unordered_map; + +extern FwdRegistry FWD_FUNCS; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp32 = float; +using fp16 = half; +using bf16 = nv_bfloat16; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeId{}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 0; +}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 1; +}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Type2Key{ + constexpr static uint32_t Value = TypeId::Value << S; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct WeightType2Key : public Type2Key{}; + +template +struct InputType2Key : public Type2Key{}; + +template +struct ResidualType2Key : public Type2Key{}; + +template +struct OutputType2Key : public Type2Key{}; + +template +struct ComputeType2Key : public Type2Key{}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct Types2Key{ + constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | ResidualType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; + constexpr static inline uint64_t get(const uint64_t hidden_size){ + constexpr uint64_t type_key = Value; + return (type_key << 32) | hidden_size; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FwdRegistrar{ + FwdRegistrar(FwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + FWD_FUNCS.insert({ key, f }); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/candle-extensions/candle-layer-norm/kernels/ln_api.cu b/candle-extensions/candle-layer-norm/kernels/ln_api.cu new file mode 100644 index 00000000..7897cd1b --- /dev/null +++ b/candle-extensions/candle-layer-norm/kernels/ln_api.cu @@ -0,0 +1,262 @@ +#include "ln.h" +#include "ln_fwd_kernels.cuh" +#include + +/* +Ada + +Supported Type combinations: + +input residual compute weights output +============================================ +fp32 fp32 fp32 fp32 fp32 +fp16 fp32 fp32 fp32 fp16 +fp16 fp16 fp32 fp32 fp16 +bf16 fp32 fp32 fp32 bf16 +bf16 bf16 fp32 fp32 bf16 +fp16 fp16 fp32 fp16 fp16 +bf16 bf16 fp32 bf16 bf16 + +Remarks: +Output type = Input type +Compute always in FP32 + +*/ + +namespace layer_norm { + +FwdRegistry FWD_FUNCS; + +uint64_t get_key(uint32_t wtype, uint32_t itype, uint32_t rtype, uint32_t otype, uint32_t ctype, uint64_t hidden_size) { + using namespace layer_norm; + uint64_t type_key = wtype | (itype << 2) | (rtype << 4) | (otype << 6) | (ctype << 8); + uint64_t launcher_key = (type_key << 32) | hidden_size; + return launcher_key; +} + +} + +layer_norm::FwdFunction & get_fwd_launcher(uint32_t wtype, uint32_t itype, uint32_t rtype, uint32_t otype, uint32_t ctype, uint32_t hidden_size) { + auto iter = 
layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + return iter->second; +} + +REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, bf16, 
fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, 
fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, bf16, 
bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, 
bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); + +REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 
8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); + +extern "C" void run_ln( + void *x, + void *residual, + void *gamma, + void *beta, + void *dst_add, + void *dst, + void *mu, + void *rsigma, + + float epsilon, + + uint32_t hidden_size_rounded, + uint32_t rows, + uint32_t cols, + int32_t multi_processor_count, + + uint32_t wtype, + uint32_t itype, + uint32_t rtype, + uint32_t otype, + uint32_t ctype, + + int is_rms_norm +) { + layer_norm::LaunchParams launch_params; + + launch_params.multi_processor_count = multi_processor_count; + launch_params.stream = 0; + + launch_params.params.dropout_keep_p = 1.f; + launch_params.params.residual = residual; + launch_params.params.rowscale = nullptr; + launch_params.params.colscale = nullptr; + launch_params.params.x0_subset = nullptr; + launch_params.params.z_subset = nullptr; + + // Request the kernel launcher. + auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size_rounded); + + // Set the kernel runtime parameters. + layer_norm::FwdParams ¶ms = launch_params.params; + + params.rows = rows; + params.cols = cols; + params.x0 = x; + params.x = dst_add; + params.dmask = nullptr; + params.mu = mu; + params.rs = rsigma; + params.gamma = gamma; + params.beta = beta; + params.z = dst; + params.epsilon = epsilon; + params.dropout_scale = 1.f; + params.inverse_cols = 1.f / float(params.cols); + params.rowscale_const = 1.f; + params.is_rms_norm = is_rms_norm; + + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + + // Launch the kernel. 
+ launcher(launch_params, false); +} diff --git a/candle-extensions/candle-layer-norm/kernels/ln_fwd_kernels.cuh b/candle-extensions/candle-layer-norm/kernels/ln_fwd_kernels.cuh new file mode 100644 index 00000000..cf6b6583 --- /dev/null +++ b/candle-extensions/candle-layer-norm/kernels/ln_fwd_kernels.cuh @@ -0,0 +1,273 @@ +#pragma once + +// #ifdef OLD_GENERATOR_PATH +// #include +// #else +// #include +// #endif + +// #include // For at::cuda::philox::unpack +#include + +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "static_switch.h" + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_fwd_kernel(FwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using residual_t = typename Ktraits::residual_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + const bool has_residual = params.residual != nullptr; + const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value); + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; 
+ const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + const input_t *rowscale = static_cast(params.rowscale); + const index_t *x0_subset = static_cast(params.x0_subset); + const index_t *z_subset = static_cast(params.z_subset); + + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu + // curandStatePhilox4_32_10_t state; + // if (Is_dropout) { + // auto seeds = at::cuda::philox::unpack(params.philox_args); + // const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; + // curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); + //} + + const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; + + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + Wvec colscale[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + gamma[it].load_from(params.gamma, idx); + if (params.beta != nullptr) { + beta[it].load_from(params.beta, idx); + } else { + beta[it].zero_(); + } + if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } + idx += VEC_COLS_PER_LDG; + } + } + + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; + const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; + const int row_z = !Has_subset ? 
row + 1 : z_subset[row]; + const bool load_x0 = !Has_subset || row_x0 > 0; + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + compute_t xf[LDGS * NUM_ELTS]; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ivec x0; + Rvec residual; + Rvec x; + Mvec dmask; + if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } + if (has_residual) { residual.load_from(params.residual, idx_x); } + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use + // the more efficient curand_uniform4. + compute_t x_ij; + if (load_x0) { + //mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; + mask_t keep = false; + // if (Is_dropout) { dmask.data.elt[jt] = keep; } + compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; + // x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; + if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } + x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; + } else { + x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f; + } + if (save_x) { x.data.elt[jt] = x_ij; } + xf[it * NUM_ELTS + jt] = x_ij; + } + if (save_x) { x.store_to(params.x, idx_x); } + // if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? 
idx_x : idx_x0); } + idx_x += VEC_COLS_PER_LDG; + idx_x0 += VEC_COLS_PER_LDG; + } + } + + static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); + const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; + const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; + const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; + auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { + // Need to convert to int, otherwise the subtraction will wrap around. + const index_t valid_partial_vecs_in_warp = + std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), + int(THREADS_PER_WARP)); + return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; + }; + stats_t s = stats.template compute( + xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS + ); + + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + mu_ptr[row] = mu; + } + + compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + rs_ptr[row] = rs; + } + + const bool save_z = !Has_subset || row_z > 0; + if (save_z) { + index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ovec z; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? 
mu : 0.f))); + compute_t g_ij = gamma[it].data.elt[jt]; + compute_t b_ij = beta[it].data.elt[jt]; + z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); + } + z.store_to(params.z, idx_z); + idx_z += VEC_COLS_PER_LDG; + } + } + } + + } +} + +} // namespace layer_norm + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG +> +void launch_(LaunchParams &launch_params, const bool configure_params){ + + using Kernel_traits = Kernel_traits; + bool has_colscale = launch_params.params.colscale != nullptr; + bool has_subset = launch_params.params.x0_subset != nullptr; + bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; + BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(has_subset, HasSubsetConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_fwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); + launch_params.params.ctas_per_col = launch_params.multi_processor_count * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename 
Kernel_traits::Stats::stats_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + }); + }); + }); + }); +} diff --git a/candle-extensions/candle-layer-norm/kernels/ln_kernel_traits.h b/candle-extensions/candle-layer-norm/kernels/ln_kernel_traits.h new file mode 100644 index 00000000..1006f8af --- /dev/null +++ b/candle-extensions/candle-layer-norm/kernels/ln_kernel_traits.h @@ -0,0 +1,172 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t THREADS_PER_CTA_ +> +struct Kernel_traits_base { + + using weight_t = weight_t_; + using input_t = input_t_; + using residual_t = residual_t_; + using output_t = output_t_; + using compute_t = compute_t_; + using index_t = index_t_; + + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; + enum { THREADS_PER_WARP = 32 }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + bool Has_colscale, + uint32_t 
THREADS_PER_CTA_, + uint32_t BYTES_PER_LDG_, + typename Base = Kernel_traits_base +> +struct Kernel_traits_finalize : public Base { + enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; + static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); + // Bytes per global load from the input. + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + // Number of elements fetched by a global load. + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; + // Bytes per global store of the weights. + enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; + static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); + static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); + // The total number of BYTES_PER_LDG-wide words in a hidden vector. + enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + + // Shared memory size to transpose the CTA result. + enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; + // Shared memory size to coalsece the CTA result. + enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; + // Shared memory requirement per CTA. + static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2; + enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT }; + + // The type of the reducer. + using Reducer = layer_norm::Reducer; + + // Condition for the whole CTA to participate in syncthreads. 
+ static_assert(COLS % Base::THREADS_PER_WARP == 0); + enum { CTAS = COLS / Base::THREADS_PER_WARP }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template< + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t HIDDEN_SIZE_, + uint32_t CTAS_PER_ROW_, + uint32_t WARPS_M_, + uint32_t WARPS_N_, + uint32_t BYTES_PER_LDG_ = 16, + typename Base = Kernel_traits_base< + HIDDEN_SIZE_, + weight_t_, + input_t_, + residual_t_, + output_t_, + compute_t_, + index_t_, + WARPS_M_*WARPS_N_*THREADS_PER_WARP + > +> +struct Kernel_traits : public Base { + + using input_t = typename Base::input_t; + using residual_t = typename Base::residual_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + // using mask_t = unsigned char; + using mask_t = bool; + + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_CTA = WARPS_M }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 
0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + + using reduce_t = typename layer_norm::TypeToVec2::Type; + using Reducer = layer_norm::Reducer; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = layer_norm::Vec; + using Rvec = layer_norm::Vec; + using Ovec = layer_norm::Vec; + using Wvec = layer_norm::Vec; + using Cvec = layer_norm::Vec; + using Mvec = layer_norm::Vec; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements in the output and weights as in the input. + static_assert(sizeof(input_t) == sizeof(output_t)); + static_assert(sizeof(input_t) <= sizeof(residual_t)); + // The number of columns fetched per load from input: one per thread. + enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. 
+ enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); + + using Stats = layer_norm::Stats; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/candle-extensions/candle-layer-norm/kernels/ln_utils.cuh b/candle-extensions/candle-layer-norm/kernels/ln_utils.cuh new file mode 100644 index 00000000..3e5eeb77 --- /dev/null +++ b/candle-extensions/candle-layer-norm/kernels/ln_utils.cuh @@ -0,0 +1,728 @@ +#pragma once + +#include + +#include +#include + +#include "ln.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr uint32_t THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void check_cuda_(cudaError_t status, const char *file, int line) { + if( status != cudaSuccess ) { + fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); + exit(status); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(ans) \ + { check_cuda_((ans), __FILE__, __LINE__); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_( \ + launch_params, configure_params); \ + } \ + static 
FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 & a, const float2 & b){ + a.x += b.x; + a.y += b.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sum { + inline __device__ Sum(){} + inline __device__ T operator()(const T &a, const T &b){ + return a + b; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ + return __shfl_xor_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ + return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; +} + +template +inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ + return __shfl_down_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ + return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint16 { + uint4 u; + uint4 v; + uint4 s; + uint4 t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint8 { + uint4 u; + uint4 v; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BytesToType {}; + +template<> +struct BytesToType<64> { + using Type = uint16; + static_assert(sizeof(Type) == 64); +}; + +template<> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template<> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeToVec2 {}; + +template<> +struct TypeToVec2 { + using Type = float2; +}; + +template<> +struct TypeToVec2 { + using Type = half2; +}; + +template<> +struct TypeToVec2 { + using Type = nv_bfloat162; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Get { + template + static inline __device__ R of(const T &vec); +}; + +template<> +template +inline __device__ R Get<0>::of(const T &vec) { + return vec.x; +} + +template<> +template +inline __device__ R Get<1>::of(const T &vec) { + return vec.y; +} + +template<> +template +inline __device__ R Get<2>::of(const T &vec) { + return vec.z; +} + +template<> +template +inline __device__ R Get<3>::of(const T &vec) { + return vec.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ Dst convert(const Src &from) { + return Dst(from); + } +}; + +template<> +struct 
Converter{ + static inline __device__ half2 convert(const float2 &x) { + return __float22half2_rn(x); + } +}; + +template<> +struct Converter{ + static inline __device__ nv_bfloat162 convert(const float2 &x) { +#if __CUDA_ARCH__ >= 800 + return __float22bfloat162_rn(x); +#else + union { + nv_bfloat162 raw; + nv_bfloat16 x; + nv_bfloat16 y; + } tmp; + tmp.x = __float2bfloat16_rn(x.x); + tmp.y = __float2bfloat16_rn(x.y); + return tmp.raw; +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zeros{ + static inline __device__ T get() { + return T(0.f); + } +}; + +template<> +struct Zeros{ + static inline __device__ float2 get() { + return make_float2(0.f, 0.f); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vec { + + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + + using Vec_type = typename BytesToType::Type; + + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; + + Alias_type data; + + template + inline __device__ void to(Vec &other) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + other.data.elt[it] = S(this->data.elt[it]); + } + } + + template + inline __device__ void assign(const Op &op) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = op(it); + } + } + + inline __device__ void zero_() { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = Elt_type(0.f); + } + } + + inline __device__ void load_from(const void *base_ptr, const size_t idx) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + inline __device__ void store_to(void *base_ptr, const size_t idx) { + static_cast(base_ptr)[idx] = this->data.vec; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct InterCTASync { + + template + inline __device__ 
InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) + : phase_counter_(0) + , b0_(params.barrier + bidm) // The barrier for this group of CTAs. + , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. + { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for( int found = -1; found != expected; ) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync(){ + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : CTAS_PER_ROW; + // There are only 4 phases: up/down for b0/b1. + phase_counter_ = (phase_counter_ + 1) & 0x3; + + if( threadIdx.x == 0 ) { + spin_wait_(barrier, step, expected); + } + // CTA waits for thread 0 + __syncthreads(); + } + + int phase_counter_; + int * b0_; + int * b1_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using InterCTASync = InterCTASync; + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , 
inter_cta_(params, bidm, bidn) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + { + } + + template + inline __device__ T allreduce(T data, Op &op) { + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + static_assert(CTAS_PER_ROW <= 32); + T total = Zeros::get(); + if(this->lane_ < CTAS_PER_ROW){ + total = workspace[this->lane_]; + } + total = Reducer::allreduce_(total, op); + + return total; + } + + InterCTASync inter_cta_; + + T *w0_; + T *w1_; + int bidn_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer { + + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_n_(warp_n) + , lane_(lane) + { + } + + template + static inline __device__ T allreduce_(T data, Op &op) { + #pragma unroll + for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { + data = op(data, warp_shuffle_xor(data, it)); + } + return data; + } + + template + inline __device__ T allreduce(T data, Op &op) { + return allreduce_(data, op); + } + + template + inline __device__ T reduce(T data, Op &op){ + // only lane 0 holds the result! 
+ #pragma unroll + for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { + data = op(data, warp_shuffle_down(data, it)); + } + return data; + } + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using Base = Reducer; + + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = &static_cast(smem)[warp_m * WARPS_N]; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ T allreduce(T data, Op & op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + return out; + } + + template + inline __device__ T reduce(T data, Op &op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! 
+ data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + } + return out; + } + + T * smem0_; + T * smem1_; + bool use0_; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){ + //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + + #pragma unroll + for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { + // Exchange + int_t n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both. + const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(uint32_t(-1), m_a, 0); + m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. 
+ + using InterCTASync = InterCTASync; + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : inter_cta_(params, bidm, bidn) + , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + , warp_n_(warp_n) + , lane_(lane) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); + + stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + if( warp_n_ == 0 && lane_ == 0 ) { + workspace[bidn_] = block_stats; + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); + + // Every warp does the final reduction locally. 
+ if( lane_ < CTAS_PER_ROW ) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + return { m, m2 }; + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t *w0_; + stats_t *w1_; + int bidn_; + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, + function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { + stats_t * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + const auto warp_n = warp_stats_.reducer_.warp_n_; + const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n)); + stats_t warp_stats = warp_stats_.template compute( + elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts + ); + + //Each warp warp leader stores its stats + const auto lane = warp_stats_.reducer_.lane_; + if( lane == 0 ) { + smem[warp_n] = warp_stats; + } + __syncthreads(); + + int n = 0;; + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if(lane < WARPS_N){ + stats_t result = smem[lane]; + n = Is_even_cols ? 
N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane); + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return { m, m2 }; + } + WarpStats warp_stats_; + stats_t * smem0_; + stats_t * smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. + using Reducer = Reducer; + + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, + // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) { + function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { + + auto sum = Sum(); + + T m = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + if (Is_even_cols || (it < num_valid_elts)) { + m += elts[it]; + } + } + m = reducer_.allreduce(m, sum) * row_norm_factor; + + T m2 = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + if (Is_even_cols || (it < num_valid_elts)) { + T diff = (elts[it] - m); + m2 += diff * diff; + } + } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } + + Reducer reducer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/candle-extensions/candle-layer-norm/kernels/static_switch.h b/candle-extensions/candle-layer-norm/kernels/static_switch.h new file mode 100644 index 00000000..7920ac04 --- /dev/null +++ b/candle-extensions/candle-layer-norm/kernels/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/candle-extensions/candle-layer-norm/src/ffi.rs b/candle-extensions/candle-layer-norm/src/ffi.rs new file mode 100644 index 00000000..dff4676d --- /dev/null +++ b/candle-extensions/candle-layer-norm/src/ffi.rs @@ -0,0 +1,29 @@ +use core::ffi::{c_int, c_void}; + +extern "C" { + pub(crate) fn run_ln( + x: *const c_void, + residual: *const c_void, + gamma: *const c_void, + beta: *const c_void, + dst_add: *const c_void, + dst: *const c_void, + mu: *const c_void, + rsigma: *const c_void, + + epsilon: f32, + + hidden_size_rounded: u32, + rows: u32, + cols: u32, + multi_processor_count: i32, + + wtype: u32, + itype: u32, + rtype: u32, + otype: u32, + ctype: u32, + + is_rms_norm: c_int, + ); +} diff --git a/candle-extensions/candle-layer-norm/src/lib.rs b/candle-extensions/candle-layer-norm/src/lib.rs new file mode 100644 index 00000000..fb147db9 --- /dev/null +++ b/candle-extensions/candle-layer-norm/src/lib.rs @@ -0,0 +1,509 @@ +mod ffi; + +use candle::backend::BackendStorage; +use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT; +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, DType, Layout, Result, Shape, Storage, Tensor}; +use half::{bf16, f16}; +use std::ptr; + +fn layer_norm_internal_type(dtype: DType) -> Result { + let internal_type = 
match dtype { + DType::F16 => 0, + DType::BF16 => 1, + DType::F32 => 2, + dtype => candle::bail!("dtype {dtype:?} is not supported"), + }; + Ok(internal_type) +} + +pub struct LayerNorm { + pub epsilon: f32, + pub is_rms_norm: bool, + pub gamma: Tensor, + pub beta: Option, +} + +fn round_multiple(x: usize, m: usize) -> usize { + (x + m - 1) / m * m +} + +impl LayerNorm { + fn fwd< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + x: &candle::CudaStorage, + x_l: &Layout, + r: Option<&candle::CudaStorage>, + r_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + // Assume all tensors are on the same device and take device of x + let dev = x.device(); + + // Get internal layer norm type id for the given dtype + let layer_norm_type = layer_norm_internal_type(x.dtype())?; + + // Make sure that gamma is a CUDA tensor and get the underlying storage + let (g, g_l) = self.gamma.storage_and_layout(); + let g = match &*g { + Storage::Cuda(g) => g, + _ => candle::bail!("gamma must be a cuda tensor"), + }; + + // Get cuda slices for all tensors + let x = x.as_cuda_slice::()?; + let g = g.as_cuda_slice::()?; + + // Get cuda views for all tensors + let x = x.slice(x_l.start_offset()..); + let g = g.slice(g_l.start_offset()..); + + // Input matrix layout + let rows = x_l.dims()[0]; + let cols = x_l.dims()[1]; + + if !(cols % 8 == 0 && cols <= 8192) { + candle::bail!("hidden size must be % 8 and <= 8192") + } + + let x_stride = x_l.stride(); + let g_stride = g_l.stride(); + + let x_rank = x_stride.len(); + let g_rank = g_stride.len(); + + if x_rank != 2 { + candle::bail!("layer-norm expects input tensors of rank 2. 
Found: {x_rank}")
+        }
+        if x_stride[x_rank - 1] != 1 {
+            candle::bail!("the last dim of x must be contiguous {x_stride:?}")
+        }
+        if g_stride[g_rank - 1] != 1 {
+            candle::bail!("the last dim of g must be contiguous {g_stride:?}")
+        }
+
+        // Round cols to match with the correct kernel
+        let cols_rounded = if cols <= 1536 {
+            round_multiple(cols, 256)
+        } else if cols <= 3072 {
+            round_multiple(cols, 512)
+        } else {
+            round_multiple(cols, 1024)
+        };
+
+        let is_rms_norm = if self.is_rms_norm { 1 } else { 0 };
+
+        // If beta is set, get its device pointer
+        let b_ptr = if let Some(beta) = &self.beta {
+            // Make sure that beta is a CUDA tensor and get the underlying storage
+            let (b, b_l) = beta.storage_and_layout();
+            let b = match &*b {
+                Storage::Cuda(b) => b,
+                _ => candle::bail!("beta must be a cuda tensor"),
+            };
+
+            let b = b.as_cuda_slice::<T>()?;
+            let b = b.slice(b_l.start_offset()..);
+
+            let b_stride = b_l.stride();
+            let b_rank = b_stride.len();
+
+            if b_stride[b_rank - 1] != 1 {
+                candle::bail!("the last dim of b must be contiguous {b_stride:?}")
+            }
+            *b.device_ptr() as *const core::ffi::c_void
+        } else {
+            ptr::null() as *const std::ffi::c_void
+        };
+
+        // If residual is set, get its device pointer
+        let r_ptr = if let (Some(r), Some(r_l)) = (r, r_l) {
+            // Check shape
+            let expected_shape = x_l.shape().dims2()?;
+            if r_l.shape().dims2()? != expected_shape {
+                candle::bail!("shape mismatch x {:?} and r {:?}", x_l.shape(), r_l.shape());
+            }
+
+            let r = r.as_cuda_slice::<T>()?;
+            let r = r.slice(r_l.start_offset()..);
+
+            let r_stride = r_l.stride();
+            let r_rank = r_stride.len();
+
+            if r_rank != 2 {
+                candle::bail!("layer-norm expects input tensors of rank 2. 
Found: {r_rank}") + } + + if r_stride[r_rank - 1] != 1 { + candle::bail!("the last dim of r must be contiguous {r_stride:?}") + } + *r.device_ptr() as *const std::ffi::c_void + } else { + ptr::null() as *const std::ffi::c_void + }; + + // We will store the results of the residual add next to the main results + // so out has the same shape as inp * 2 + let out_shape = Shape::from((rows * 2, cols)); + + let out = unsafe { dev.alloc::(out_shape.elem_count()) }.w()?; + let dst = out.slice(..rows * cols); + let dst_add = out.slice(rows * cols..); + + // Alloc internal buffers + let mu = unsafe { dev.alloc::(rows) }.w()?; + let rsigma = unsafe { dev.alloc::(rows) }.w()?; + + // Get cuda device pointers from cuda slices + let x_ptr = *x.device_ptr() as *const core::ffi::c_void; + let g_ptr = *g.device_ptr() as *const core::ffi::c_void; + let dst_add_ptr = *dst_add.device_ptr() as *const core::ffi::c_void; + let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; + let mu_ptr = *mu.device_ptr() as *const core::ffi::c_void; + let rsigma_ptr = *rsigma.device_ptr() as *const core::ffi::c_void; + + let multi_processors_count = dev + .attribute(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT) + .unwrap(); + + unsafe { + // Launch Kernel + ffi::run_ln( + x_ptr, + r_ptr, + g_ptr, + b_ptr, + dst_add_ptr, + dst_ptr, + mu_ptr, + rsigma_ptr, + self.epsilon, + cols_rounded as u32, + rows as u32, + cols as u32, + multi_processors_count, + layer_norm_type, + layer_norm_type, + layer_norm_type, + layer_norm_type, + 2, + is_rms_norm, + ) + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } +} + +impl candle::CustomOp1 for LayerNorm { + fn name(&self) -> &'static str { + "fused-layer-norm" + } + + fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for fused-layer-norm") + } + + fn cuda_fwd( + &self, + x: &candle::CudaStorage, + x_l: &Layout, + ) -> Result<(candle::CudaStorage, 
Shape)> { + match x.dtype() { + DType::F16 => self.fwd::(x, x_l, None, None), + DType::BF16 => self.fwd::(x, x_l, None, None), + DType::F32 => self.fwd::(x, x_l, None, None), + dt => { + candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})") + } + } + } +} + +impl candle::CustomOp2 for LayerNorm { + fn name(&self) -> &'static str { + "fused-layer-norm" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for fused-layer-norm") + } + + fn cuda_fwd( + &self, + x: &candle::CudaStorage, + x_l: &Layout, + r: &candle::CudaStorage, + r_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match x.dtype() { + DType::F16 => self.fwd::(x, x_l, Some(r), Some(r_l)), + DType::BF16 => self.fwd::(x, x_l, Some(r), Some(r_l)), + DType::F32 => self.fwd::(x, x_l, Some(r), Some(r_l)), + dt => { + candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})") + } + } + } +} + +/// Layer Normalization Layer +/// +/// # Arguments +/// +/// * `x` - Input tensor of rank 2 +/// * `gamma` - Channel scale +/// * `beta` - Channel bias +/// * `epsilon` - A value added to the denominator for numerical stability +/// +/// The resulting tensor has the same dimensions as `x` +pub fn layer_norm( + x: &Tensor, + gamma: &Tensor, + beta: Option<&Tensor>, + epsilon: f32, +) -> Result { + let op = LayerNorm { + epsilon, + gamma: gamma.clone(), + beta: beta.cloned(), + is_rms_norm: false, + }; + let results = x.apply_op1(op)?; + let rows = x.dims()[0]; + results.narrow(0, 0, rows) +} + +/// Fused Add Layer Normalization Layer +/// +/// # Arguments +/// +/// * `x` - Input tensor of rank 2 +/// * `res` - Residual tensor of rank 2. Will be added to `x` before normalization. Must have +/// the same shape as `x`. 
+/// * `gamma` - Channel scale +/// * `beta` - Channel bias +/// * `epsilon` - A value added to the denominator for numerical stability +/// +/// The resulting tensors have the same dimensions as `x` +/// First tensor is the result of the normalization, second is the result of the residual add +pub fn fused_add_layer_norm( + x: &Tensor, + res: &Tensor, + gamma: &Tensor, + beta: Option<&Tensor>, + epsilon: f32, +) -> Result<(Tensor, Tensor)> { + let op = LayerNorm { + epsilon, + gamma: gamma.clone(), + beta: beta.cloned(), + is_rms_norm: false, + }; + let results = x.apply_op2(&res, op)?; + let rows = x.dims()[0]; + Ok((results.narrow(0, 0, rows)?, results.narrow(0, rows, rows)?)) +} + +/// Layer RMS Normalization Layer +/// +/// # Arguments +/// +/// * `x` - Input tensor of rank 2 +/// * `gamma` - Channel scale +/// * `beta` - Channel bias +/// * `epsilon` - A value added to the denominator for numerical stability +/// +/// The resulting tensor has the same dimensions as `x` +pub fn rms_norm(x: &Tensor, gamma: &Tensor, beta: Option<&Tensor>, epsilon: f32) -> Result { + let op = LayerNorm { + epsilon, + gamma: gamma.clone(), + beta: beta.cloned(), + is_rms_norm: true, + }; + let results = x.apply_op1(op)?; + let rows = x.dims()[0]; + results.narrow(0, 0, rows) +} + +/// Fused Add RMS Normalization Layer +/// +/// # Arguments +/// +/// * `x` - Input tensor of rank 2 +/// * `res` - Residual tensor of rank 2. Will be added to `x` before normalization. Must have +/// the same shape as `x`. 
+/// * `gamma` - Channel scale +/// * `beta` - Channel bias +/// * `epsilon` - A value added to the denominator for numerical stability +/// +/// The resulting tensors have the same dimensions as `x` +/// First tensor is the result of the normalization, second is the result of the residual add +pub fn fused_add_rms_norm( + x: &Tensor, + res: &Tensor, + gamma: &Tensor, + beta: Option<&Tensor>, + epsilon: f32, +) -> Result<(Tensor, Tensor)> { + let op = LayerNorm { + epsilon, + gamma: gamma.clone(), + beta: beta.cloned(), + is_rms_norm: true, + }; + let results = x.apply_op2(&res, op)?; + let rows = x.dims()[0]; + Ok((results.narrow(0, 0, rows)?, results.narrow(0, rows, rows)?)) +} + +#[cfg(test)] +mod tests { + use super::*; + use candle::{DType, Device}; + + fn layer_norm_truth( + x: &Tensor, + gamma: &Tensor, + beta: Option<&Tensor>, + epsilon: f64, + rms: bool, + ) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + + let (_seq_len, hidden_size) = x.shape().dims2()?; + let x = x.to_dtype(internal_dtype)?; + + let x = if !rms { + let mean_x = (x.sum_keepdim(1)? / hidden_size as f64)?; + x.broadcast_sub(&mean_x)? + } else { + x + }; + + let norm_x = (x.sqr()?.sum_keepdim(1)? 
/ hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + epsilon)?.sqrt()?)?; + + let mut x = x_normed.to_dtype(x_dtype)?.broadcast_mul(gamma)?; + if let Some(beta) = beta { + x = x.broadcast_add(beta)?; + } + Ok(x) + } + + fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_dtype(DType::F32)?.to_vec2::()?; + let t = t + .iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect(); + Ok(t) + } + + #[test] + fn test_layer_norm() -> Result<()> { + let device = Device::new_cuda(0)?; + + let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let res = layer_norm(&x, &g, Some(&b), 1e-12)?; + let truth = layer_norm_truth(&x, &g, Some(&b), 1e-12, false)?; + + assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?); + Ok(()) + } + + #[test] + fn test_layer_norm_no_bias() -> Result<()> { + let device = Device::new_cuda(0)?; + + let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let res = layer_norm(&x, &g, None, 1e-12)?; + let truth = layer_norm_truth(&x, &g, None, 1e-12, false)?; + + assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?); + Ok(()) + } + + #[test] + fn test_rms_norm() -> Result<()> { + let device = Device::new_cuda(0)?; + + let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let res = rms_norm(&x, &g, Some(&b), 1e-12)?; + let truth = layer_norm_truth(&x, &g, Some(&b), 1e-12, true)?; + assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?); + Ok(()) + } + + #[test] + fn test_rms_norm_no_bias() -> Result<()> { + let device = Device::new_cuda(0)?; + + let x = 
Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let res = rms_norm(&x, &g, None, 1e-12)?; + let truth = layer_norm_truth(&x, &g, None, 1e-12, true)?; + + assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?); + Ok(()) + } + + #[test] + fn test_layer_norm_add() -> Result<()> { + let device = Device::new_cuda(0)?; + + let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let r = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let (res, res_add) = fused_add_layer_norm(&x, &r, &g, Some(&b), 1e-12)?; + let truth_add = (x + r)?; + let truth = layer_norm_truth(&truth_add, &g, Some(&b), 1e-12, false)?; + assert_eq!(to_vec2_round(res_add, 3)?, to_vec2_round(truth_add, 3)?); + assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?); + Ok(()) + } + + #[test] + fn test_rms_norm_add() -> Result<()> { + let device = Device::new_cuda(0)?; + + let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let r = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?; + let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let (res, res_add) = fused_add_rms_norm(&x, &r, &g, Some(&b), 1e-12)?; + let truth_add = (x + r)?; + let truth = layer_norm_truth(&truth_add, &g, Some(&b), 1e-12, true)?; + assert_eq!(to_vec2_round(res_add, 3)?, to_vec2_round(truth_add, 3)?); + assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?); + Ok(()) + } +} diff --git a/candle-extensions/candle-rotary/.gitignore b/candle-extensions/candle-rotary/.gitignore new file mode 100644 index 00000000..fbc9a58c --- /dev/null +++ b/candle-extensions/candle-rotary/.gitignore @@ -0,0 +1,3 @@ +.idea +target +Cargo.lock diff --git 
a/candle-extensions/candle-rotary/Cargo.toml b/candle-extensions/candle-rotary/Cargo.toml new file mode 100644 index 00000000..33f4eb81 --- /dev/null +++ b/candle-extensions/candle-rotary/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "candle-rotary" +version = "0.0.1" +edition = "2021" + +description = "Rotary layer for the candle ML framework." +keywords = ["tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { workspace = true, features = ["cuda"]} +half = { workspace = true } + +[build-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +bindgen_cuda = "0.1.1" + +[dev-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +candle-nn = { version = "0.3.0", features = ["cuda"] } diff --git a/candle-extensions/candle-rotary/LICENSE-APACHE b/candle-extensions/candle-rotary/LICENSE-APACHE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/candle-extensions/candle-rotary/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/candle-extensions/candle-rotary/LICENSE-MIT b/candle-extensions/candle-rotary/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/candle-extensions/candle-rotary/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/candle-extensions/candle-rotary/README.md b/candle-extensions/candle-rotary/README.md
new file mode 100644
index 00000000..303d92d7
--- /dev/null
+++ b/candle-extensions/candle-rotary/README.md
@@ -0,0 +1,4 @@
+# Candle Rotary
+
+All files in `kernels` are adapted from https://github.com/vllm-project/vllm/tree/main/csrc and are under the vLLM
+Project copyright.
diff --git a/candle-extensions/candle-rotary/build.rs b/candle-extensions/candle-rotary/build.rs
new file mode 100644
index 00000000..46010fba
--- /dev/null
+++ b/candle-extensions/candle-rotary/build.rs
@@ -0,0 +1,56 @@
+// Build script to run nvcc and generate the C glue code for launching the rotary kernel.
+// The cuda build time is very long so one can set the CANDLE_ROTARY_BUILD_DIR environment
+// variable in order to cache the compiled artifacts and avoid recompiling too often.
+use anyhow::{Context, Result};
+use std::path::PathBuf;
+
+const KERNEL_FILES: [&str; 1] = ["kernels/rotary.cu"];
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+    for kernel_file in KERNEL_FILES.iter() {
+        println!("cargo:rerun-if-changed={kernel_file}");
+    }
+    let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
+    let build_dir = match std::env::var("CANDLE_ROTARY_BUILD_DIR") {
+        Err(_) =>
+        {
+            #[allow(clippy::redundant_clone)]
+            out_dir.clone()
+        }
+        Ok(build_dir) => {
+            let path = PathBuf::from(build_dir);
+            path.canonicalize().expect(&format!(
+                "Directory doesn't exist: {} (the current directory is {})",
+                &path.display(),
+                std::env::current_dir()?.display()
+            ))
+        }
+    };
+
+    let kernels: Vec<_> = KERNEL_FILES.iter().collect();
+    let builder = bindgen_cuda::Builder::default()
+        .kernel_paths(kernels)
+        .out_dir(build_dir.clone())
+        .arg("-std=c++17")
+        .arg("-O3")
+        .arg("-U__CUDA_NO_HALF_OPERATORS__")
+        .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+        .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+        .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+        
.arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--ptxas-options=-v") + .arg("--verbose"); + + let out_file = build_dir.join("librotary.a"); + builder.build_lib(out_file); + + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=rotary"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=stdc++"); + + Ok(()) +} diff --git a/candle-extensions/candle-rotary/kernels/cuda_compat.h b/candle-extensions/candle-rotary/kernels/cuda_compat.h new file mode 100644 index 00000000..ed3ebe7d --- /dev/null +++ b/candle-extensions/candle-rotary/kernels/cuda_compat.h @@ -0,0 +1,27 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif diff --git a/candle-extensions/candle-rotary/kernels/rotary.cu b/candle-extensions/candle-rotary/kernels/rotary.cu new file mode 100644 index 00000000..b5c975cb --- /dev/null +++ b/candle-extensions/candle-rotary/kernels/rotary.cu @@ -0,0 +1,131 @@ +#include +#include +#include + +#include "cuda_compat.h" + +namespace vllm { + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + 
// One thread block per token: applies rotary position embedding in place to
// the query and key projections of that token.
//
// `cos_cache` / `sin_cache` are per-token tables of width `rot_dim`. Here
// `rot_dim` is the number of rotated *pairs* per head (half the rotated span):
// in NeoX mode elements `i` and `rot_dim + i` are rotated together, in GPT-J
// mode the adjacent elements `2i` and `2i + 1` are. The caller must therefore
// guarantee `head_size >= 2 * rot_dim` — assumed, not checked here; TODO
// confirm at the call site.
//
// `query_stride` / `key_stride` are the element distances between consecutive
// tokens, so a non-contiguous token dimension is allowed; heads within a token
// are assumed packed at `head_size` apart.
// NOTE(review): the diff rendering stripped template argument lists; the
// `<typename scalar_t, bool IS_NEOX>` header and explicit call arguments below
// are reconstructed — IS_NEOX cannot be deduced, so they must have been
// explicit in the original.
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    scalar_t* __restrict__ query,            // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,              // [num_tokens, num_kv_heads, head_size]
    const scalar_t* __restrict__ cos_cache,  // [num_tokens, rot_dim]
    const scalar_t* __restrict__ sin_cache,  // [num_tokens, rot_dim]
    const int rot_dim,
    const int64_t query_stride,
    const int64_t key_stride,
    const int num_heads,
    const int num_kv_heads,
    const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;

  // This token's row of the cos/sin tables.
  const scalar_t* cos_ptr = cos_cache + token_idx * rot_dim;
  const scalar_t* sin_ptr = sin_cache + token_idx * rot_dim;

  // Threads stride over all (head, rotation-offset) pairs of the query heads…
  const int nq = num_heads * rot_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / rot_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % rot_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, rot_dim);
  }

  // …then over the (possibly fewer, for grouped-query attention) key heads.
  const int nk = num_kv_heads * rot_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / rot_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % rot_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, rot_dim);
  }
}
/// Monomorphic worker behind [`apply_rotary_inplace`]: validates dtypes,
/// devices and shapes, then launches the `rotary_embedding` CUDA kernel,
/// mutating `query` and `key` in place.
///
/// `cos_cache` / `sin_cache` must be `(num_tokens, rot_dim)` where `rot_dim`
/// is taken from the cos cache's second dimension.
///
/// # Errors
/// Fails if dtypes differ or are unsupported, any tensor is not on CUDA, or
/// the ranks/shapes are inconsistent.
///
/// NOTE(review): only the token stride (`stride()[0]`) is forwarded to the
/// kernel, so heads and elements within a token are assumed contiguous, and
/// `head_size >= 2 * rot_dim` is assumed but never checked — confirm callers
/// guarantee both. The rank-error messages label the operands `k:`/`v:`
/// although they describe q/k and the caches — presumably another copy-paste
/// from a k/v kernel; verify before relying on them.
fn apply_rotary_<
    T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
    query: &Tensor,
    key: &Tensor,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_neox: bool,
) -> Result<()> {
    // All four tensors must share one dtype: the kernel is monomorphized once.
    let dtype = query.dtype();
    if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
        candle::bail!("apply-rotary expects all tensors to have the same dtype");
    }

    // Numeric dtype tag understood by the C side: 0 => f16, 1 => bf16, 2 => f32.
    let internal_type = match dtype {
        DType::F16 => 0,
        DType::BF16 => 1,
        DType::F32 => 2,
        dtype => candle::bail!("dtype {dtype:?} is not supported"),
    };

    // Unwrap the CUDA storage (and keep the layout) for every tensor.
    let (q, q_l) = query.storage_and_layout();
    let q = match &*q {
        Storage::Cuda(q) => q,
        _ => candle::bail!("query must be a cuda tensor"),
    };

    let (k, k_l) = key.storage_and_layout();
    let k = match &*k {
        Storage::Cuda(k) => k,
        _ => candle::bail!("key must be a cuda tensor"),
    };

    let (cc, cc_l) = cos_cache.storage_and_layout();
    let cc = match &*cc {
        Storage::Cuda(cc) => cc,
        _ => candle::bail!("cos_cache must be a cuda tensor"),
    };

    let (sc, sc_l) = sin_cache.storage_and_layout();
    let sc = match &*sc {
        Storage::Cuda(sc) => sc,
        _ => candle::bail!("sin_cache must be a cuda tensor"),
    };

    // Rank checks: q/k are (tokens, heads, head_size), caches are (tokens, rot_dim).
    let q_rank = q_l.stride().len();
    let k_rank = k_l.stride().len();
    let cc_rank = cc_l.stride().len();
    let sc_rank = sc_l.stride().len();

    if q_rank != 3 || k_rank != 3 {
        candle::bail!("apply-rotary expects input tensors of rank 3 (k: {q_l:?}, v: {k_l:?})")
    }

    if cc_rank != 2 || sc_rank != 2 {
        candle::bail!("apply-rotary expects cache tensors of rank 2 (k: {cc_l:?}, v: {sc_l:?})")
    }

    // Get cuda slices for all tensors
    let q = q.as_cuda_slice::<T>()?;
    let k = k.as_cuda_slice::<T>()?;
    let cc = cc.as_cuda_slice::<T>()?;
    let sc = sc.as_cuda_slice::<T>()?;

    // Get cuda views for all tensors, honouring each layout's start offset.
    let q = q.slice(q_l.start_offset()..);
    let k = k.slice(k_l.start_offset()..);
    let cc = cc.slice(cc_l.start_offset()..);
    let sc = sc.slice(sc_l.start_offset()..);

    let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
    let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

    // q and k may differ in head count (grouped-query attention) but must
    // agree on token count and head size.
    if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
        candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
    }

    // rot_dim is defined by the cos cache width; sin must match it exactly.
    let rot_dim = cc_l.dims()[1];
    if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch cos_cache {:?}, expected {:?}",
            cc_l.shape(),
            (num_tokens, rot_dim)
        )
    }

    if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch sin_cache {:?}, expected {:?}",
            sc_l.shape(),
            (num_tokens, rot_dim)
        )
    }

    // Distance (in elements) between consecutive tokens; inner dims assumed
    // contiguous (see NOTE above).
    let query_stride = q_l.stride()[0];
    let key_stride = k_l.stride()[0];

    // Raw device pointers handed to the C kernel launcher.
    let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
    let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
    let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
    let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;

    let neox = if is_neox { 1 } else { 0 };

    // SAFETY: pointers are live for the duration of the call (the storage
    // borrows above are still held) and the shape/stride/dtype arguments were
    // validated to match the tensors.
    unsafe {
        ffi::rotary_embedding(
            q_ptr,
            k_ptr,
            cc_ptr,
            sc_ptr,
            neox,
            head_size as c_int,
            num_tokens as c_long,
            rot_dim as c_int,
            num_heads as c_int,
            num_kv_heads as c_int,
            query_stride as c_long,
            key_stride as c_long,
            internal_type,
        )
    }
    Ok(())
}
+/// * `cos_cache` - Aligned cache of shape `(num_tokens, rot_dim)` +/// * `sin_cache` - Aligned cache of shape `(num_tokens, rot_dim)` +/// * `is_neox` - Use neox encoding instead of gpt-j style rotary +pub fn apply_rotary_inplace( + query: &Tensor, + key: &Tensor, + cos_cache: &Tensor, + sin_cache: &Tensor, + is_neox: bool, +) -> Result<()> { + match key.dtype() { + DType::F16 => apply_rotary_::(query, key, cos_cache, sin_cache, is_neox), + DType::BF16 => apply_rotary_::(query, key, cos_cache, sin_cache, is_neox), + DType::F32 => apply_rotary_::(query, key, cos_cache, sin_cache, is_neox), + dt => { + candle::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})") + } + } +} diff --git a/candle-extensions/candle-rotary/tests/rotary_tests.rs b/candle-extensions/candle-rotary/tests/rotary_tests.rs new file mode 100644 index 00000000..444d561d --- /dev/null +++ b/candle-extensions/candle-rotary/tests/rotary_tests.rs @@ -0,0 +1,85 @@ +use anyhow::Result; +use candle::{DType, Device, Tensor, D}; + +fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) +} + +fn rotate_half(xs: &Tensor) -> candle::error::Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +fn freqs(rot_dim: usize, seqlen: usize, dev: &Device) -> candle::error::Result { + let inv_freq: Vec<_> = (0..rot_dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / rot_dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, seqlen as u32, dev)? + .to_dtype(DType::F32)? 
+ .reshape((seqlen, 1))?; + t.matmul(&inv_freq) +} + +fn apply_rotary_emb_qkv( + q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor, +) -> Result<(Tensor, Tensor)> { + let cos = cos.unsqueeze(1)?; // (seq_len, 1, dim) + let sin = sin.unsqueeze(1)?; // (seq_len, 1, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) +} + +#[test] +fn rotary() -> Result<()> { + let device = Device::new_cuda(0)?; + + let seqlen = 12; + let num_heads = 8; + let rot_dim = 64; + + let q = Tensor::randn(0.0, 1.0, (seqlen, num_heads, rot_dim), &device)?.to_dtype(DType::F32)?; + let k = Tensor::randn(0.0, 1.0, (seqlen, num_heads, rot_dim), &device)?.to_dtype(DType::F32)?; + + let (expected_q, expected_k) = { + let freqs = freqs(rot_dim, seqlen, &device)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + apply_rotary_emb_qkv(&q, &k, &freqs.cos()?, &freqs.sin()?) + }?; + + // Create inv freqs + let inv_freqs = candle_rotary::inv_freqs(rot_dim, 10000f32, &device)?; + // Create an over-sized cos sin cache like you would usually do + let (cos, sin) = candle_rotary::cos_sin(32, &inv_freqs, DType::F32)?; + // Positions for seqlen + let position_ids = Tensor::arange(0, seqlen as u32, &device)?; + // Filter cos and sin + let cos = cos.index_select(&position_ids, 0)?; + let sin = sin.index_select(&position_ids, 0)?; + + // Inplace + candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?; + + assert_eq!(to_vec3_round(expected_q, 3)?, to_vec3_round(q, 3)?); + assert_eq!(to_vec3_round(expected_k, 3)?, to_vec3_round(k, 3)?); + + Ok(()) +} From ba869c45ada27113e8d6d7807cfeedfa57e5e93b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Mar 2025 16:13:38 +0100 Subject: [PATCH 2/3] Take the submodules. 
--- .github/workflows/build.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 1c41b611..23498f34 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -29,6 +29,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 + with: + submodules: true - id: set-matrix env: From b995c21ebac1f33076d78b69fb0c0e4ff3b67b3b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Mar 2025 17:14:34 +0100 Subject: [PATCH 3/3] Submodules placed differently. --- .github/workflows/build.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 23498f34..3cb8a221 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -29,8 +29,6 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 - with: - submodules: true - id: set-matrix env: @@ -60,6 +58,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + submodules: true - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v3