From 2a27823bceb0f19f9217cc4814e6c1a9c993237d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 26 Aug 2025 20:47:36 +0000 Subject: [PATCH 01/53] Test working as I think it should work MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. 
Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: 
Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. 
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see
https://pre-commit.ci Signed-off-by: Varun Thumbe Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> Signed-off-by: Varun Thumbe minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe fix linting error Signed-off-by: Varun Thumbe --- docs/api/pytorch.rst | 5 +- docs/examples/onnx/onnx_export.ipynb | 2 +- .../te_layer_with_overlap.py | 8 +- qa/L0_cppunittest/test.sh | 2 +- qa/L1_cpp_distributed/test.sh | 15 + qa/L1_jax_distributed_unittest/test.sh | 1 + setup.py | 13 + tests/cpp/CMakeLists.txt | 1 + tests/cpp/comm_gemm/CMakeLists.txt | 19 + tests/cpp/comm_gemm/test_comm_gemm.cu | 441 +++++++++++++++ tests/jax/multi_process_launch.sh | 23 + tests/jax/test_fused_attn.py | 9 + tests/jax/test_helper.py | 37 +- tests/jax/test_layer.py | 45 +- ..._multi_process_distributed_grouped_gemm.py | 164 ++++++ .../distributed/run_layer_with_overlap.py | 81 ++- .../test_fusible_ops_with_userbuffers.py | 8 +- tests/pytorch/test_fusible_ops.py | 56 +- tests/pytorch/test_numerics.py | 17 +- tests/pytorch/test_onnx_export.py | 18 +- tests/pytorch/test_sanity.py | 50 +- tests/pytorch/utils.py | 2 +- transformer_engine/common/CMakeLists.txt | 28 + transformer_engine/common/__init__.py | 33 ++ .../common/comm_gemm/comm_gemm.cpp | 519 ++++++++++++++++++ .../comm_gemm_overlap/comm_gemm_overlap.cpp | 26 +- .../userbuffers/userbuffers-host.cpp | 2 +- .../userbuffers/userbuffers.cu | 13 +- .../userbuffers/userbuffers.h | 2 +- transformer_engine/common/common.cu | 22 +- transformer_engine/common/common.h | 16 +- transformer_engine/common/dropout/dropout.cu | 355 ++++++++++++ .../common/fused_attn/context_parallel.cu | 9 + .../common/fused_attn/flash_attn.cu | 2 + .../fused_attn_f16_arbitrary_seqlen.cu | 4 + .../common/fused_attn/fused_attn_fp8.cu | 4 + .../common/fused_attn/kv_cache.cu | 4 + transformer_engine/common/fused_attn/utils.cu | 8 +- .../common/fused_router/fused_moe_aux_loss.cu | 12 +- .../fused_score_for_moe_aux_loss.cu | 2 + .../fused_topk_with_score_function.cu | 2 + .../scaled_aligned_causal_masked_softmax.cu | 2 + 
.../fused_softmax/scaled_masked_softmax.cu | 3 + .../scaled_upper_triang_masked_softmax.cu | 2 + .../common/gemm/cublaslt_gemm.cu | 18 - .../include/transformer_engine/comm_gemm.h | 156 ++++++ .../include/transformer_engine/dropout.h | 51 ++ .../common/multi_tensor/l2norm.cu | 2 + .../common/normalization/common.cpp | 4 +- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 34 +- .../layernorm/ln_fwd_cuda_kernel.cu | 31 +- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 16 +- .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 31 +- .../common/nvshmem_api/nvshmem_waitkernel.cu | 15 +- .../common/permutation/permutation.cu | 6 + .../common/recipe/current_scaling.cu | 2 +- .../common/recipe/fp8_block_scaling.cu | 2 + transformer_engine/common/swizzle/swizzle.cu | 56 +- .../common/transformer_engine.cpp | 4 +- .../common/transpose/cast_transpose.cu | 1 + .../common/transpose/cast_transpose_fusion.cu | 16 +- .../common/transpose/multi_cast_transpose.cu | 4 + .../common/transpose/transpose.cu | 1 + .../common/transpose/transpose_fusion.cu | 13 +- .../common/util/cast_gated_kernels.cuh | 24 +- .../common/util/cast_kernels.cuh | 18 +- .../common/util/dequantize_kernels.cuh | 1 + transformer_engine/common/util/logging.h | 17 + transformer_engine/common/util/padding.cu | 4 + .../common/util/vectorized_pointwise.h | 4 + .../jax/cpp_extensions/attention.py | 6 + transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/cpp_extensions/gemm.py | 6 +- transformer_engine/jax/cpp_extensions/misc.py | 10 + .../jax/cpp_extensions/quantization.py | 64 ++- .../jax/csrc/extensions/gemm.cpp | 5 +- .../jax/csrc/extensions/quantization.cpp | 11 +- transformer_engine/jax/dense.py | 131 ++++- transformer_engine/jax/flax/module.py | 31 +- transformer_engine/jax/layernorm_mlp.py | 7 + transformer_engine/jax/quantize/helper.py | 250 +++++---- transformer_engine/jax/quantize/quantizer.py | 79 ++- .../jax/quantize/scaling_modes.py | 2 +- transformer_engine/jax/sharding.py | 32 +- transformer_engine/pytorch/__init__.py | 1 + .../attention/dot_product_attention/utils.py | 16 +- transformer_engine/pytorch/cpu_offload.py | 30 +- transformer_engine/pytorch/csrc/extensions.h | 11 + .../pytorch/csrc/extensions/cast.cpp | 14 +- .../pytorch/csrc/extensions/dropout.cpp | 89 +++ .../pytorch/csrc/extensions/gemm.cpp | 48 +- .../pytorch/csrc/extensions/pybind.cpp | 7 + transformer_engine/pytorch/csrc/quantizer.cpp | 49 +- transformer_engine/pytorch/graph.py | 31 +- transformer_engine/pytorch/module/__init__.py | 2 +- transformer_engine/pytorch/module/base.py | 120 ++-- .../pytorch/module/layernorm_linear.py | 22 +- .../pytorch/module/layernorm_mlp.py | 24 +- transformer_engine/pytorch/module/linear.py | 22 +- transformer_engine/pytorch/onnx_extensions.py | 54 +- .../pytorch/ops/basic/dropout.py | 69 ++- .../ops/fused/userbuffers_backward_linear.py | 10 +- .../ops/fused/userbuffers_forward_linear.py | 2 +- transformer_engine/pytorch/setup.py | 105 +++- .../pytorch/tensor/float8_tensor.py | 22 +- 105 files changed, 3366 insertions(+), 636 deletions(-) create mode 100755 qa/L1_cpp_distributed/test.sh create mode 100644 tests/cpp/comm_gemm/CMakeLists.txt create mode 100644 tests/cpp/comm_gemm/test_comm_gemm.cu create mode 100644 tests/jax/multi_process_launch.sh create mode 100644 tests/jax/test_multi_process_distributed_grouped_gemm.py create mode 100644 transformer_engine/common/comm_gemm/comm_gemm.cpp create mode 100644 transformer_engine/common/dropout/dropout.cu create mode 100644 
transformer_engine/common/include/transformer_engine/comm_gemm.h create mode 100644 transformer_engine/common/include/transformer_engine/dropout.h create mode 100644 transformer_engine/pytorch/csrc/extensions/dropout.cpp diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 3229298f2d..04b49fac2f 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -49,7 +49,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute -.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs +.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs .. autoapifunction:: transformer_engine.pytorch.moe_unpermute @@ -62,3 +62,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub + +.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode + :members: FP8, NONE \ No newline at end of file diff --git a/docs/examples/onnx/onnx_export.ipynb b/docs/examples/onnx/onnx_export.ipynb index 91fc380037..26ac71188c 100644 --- a/docs/examples/onnx/onnx_export.ipynb +++ b/docs/examples/onnx/onnx_export.ipynb @@ -10,7 +10,7 @@ "\n", "Note:\n", "\n", - "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n", + "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n", "\n", "\n", "\n", diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index e510df1761..d52e97d65c 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) te.module.base.initialize_ub( [batched_size, hidden_size], tp_size, - use_fp8=opts.fp8, + quantization_modes=[ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) + ], dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ) diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index cd46b0b63c..aa56d69ed6 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp cmake -GNinja -Bbuild . cmake --build build export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) -ctest --test-dir build -j$NUM_PARALLEL_JOBS +ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)' diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh new file mode 100755 index 0000000000..f4f914b3e9 --- /dev/null +++ b/qa/L1_cpp_distributed/test.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Find TE +: ${TE_PATH:=/opt/transformerengine} +TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') +export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH + +cd $TE_PATH/tests/cpp +cmake -GNinja -S. 
-Bbuild +cmake --build build +mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index f332e32e85..8ecc5a9178 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -9,3 +9,4 @@ set -xe mkdir -p "$XML_LOG_DIR" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh diff --git a/setup.py b/setup.py index 0b1b523277..52adaf9238 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ """Installation script.""" +from importlib import metadata import os import time from pathlib import Path @@ -66,6 +67,18 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") + cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( + "nvidia-cublasmp-cu12" + ).locate_file("nvidia/cublasmp/cu12") + cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") + nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( + "nvidia-nvshmem-cu12" + ).locate_file("nvidia/nvshmem") + cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") + print("CMAKE_FLAGS:", cmake_flags[-2:]) + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index eb2825ba41..c2c9d0d915 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -37,6 +37,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_ message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) diff --git a/tests/cpp/comm_gemm/CMakeLists.txt b/tests/cpp/comm_gemm/CMakeLists.txt new file mode 100644 index 0000000000..55f5207acf --- /dev/null +++ b/tests/cpp/comm_gemm/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/comm_gemm/test_comm_gemm.cu b/tests/cpp/comm_gemm/test_comm_gemm.cu new file mode 100644 index 0000000000..b34d4db4b8 --- /dev/null +++ b/tests/cpp/comm_gemm/test_comm_gemm.cu @@ -0,0 +1,441 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "common.h" + +using transformer_engine::DType; +using transformer_engine::TypeInfo; + +#define CHECK_MPI(expr) \ + do { \ + int err = (expr); \ + if (err != MPI_SUCCESS) { \ + char err_str[MPI_MAX_ERROR_STRING + 1]{}; \ + int _len{}; \ + MPI_Error_string(err, err_str, &_len); \ + EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \ + } \ + } while (false) + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t err = (expr); \ + if (err != ncclSuccess) { \ + EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \ + } \ + } while (false) + +#define CHECK_CU(expr) \ + do { \ + CUresult err = (expr); \ + if (err != CUDA_SUCCESS) { \ + const char* str{}; \ + CUresult e_str = cuGetErrorString(err, &str); \ + if (e_str != CUDA_SUCCESS) str = "(unknown)"; \ + EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \ + } \ + } while (false) + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + CHECK_MPI(MPI_Init(&argc, &argv)); + auto ret = RUN_ALL_TESTS(); + CHECK_MPI(MPI_Finalize()); + return ret; +} + +bool IsMulticastSupported(int device_id) { + int supported = 0; + CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id)); + return supported; +} + +template +std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld) { + std::vector ret(msize * nsize); + size_t dst = 0; + for (size_t j = nstart; j < nstart + nsize; ++j) { + for (size_t i = mstart; i < mstart + msize; ++i) { + ret[dst++] = data[j * ld + i]; + } + } + return ret; +} + +template +test::Tensor Make(size_t m, size_t n, float scale) { + test::Tensor ret("", std::vector{n, m}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + return ret; +} + +template +test::Tensor MakeFromData(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld, float scale) { + test::Tensor ret("", std::vector{nsize, msize}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld); + NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0], + cudaMemcpyDefault)); + return ret; +} + +template +float GetScale(float amax) { + if constexpr (sizeof(T) > 1) return 1.0; + return static_cast(static_cast(std::numeric_limits::max())) / amax; +} + +struct Params { + DType a_type; + DType b_type; + DType d_type; + bool transa; + bool transb; + size_t m; + size_t n; + size_t k; + float tol; +}; + +class CommGemmFixure : public ::testing::TestWithParam { + protected: + CommGemmFixure() { + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); + NVTE_CHECK_CUDA(cudaSetDevice(rank_)); + ncclUniqueId id{}; + if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id)); + CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_)); + ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_); + } + ~CommGemmFixure() { + nvte_comm_gemm_ctx_destroy(ctx_); + ncclCommDestroy(comm_); + } + + struct PatternDims { + int64_t 
a_rows_start; + int64_t a_rows_num; + int64_t a_cols_start; + int64_t a_cols_num; + int64_t b_rows_start; + int64_t b_rows_num; + int64_t b_cols_start; + int64_t b_cols_num; + int64_t d_rows_start; + int64_t d_rows_num; + int64_t d_cols_start; + int64_t d_cols_num; + }; + + virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0; + + virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) = 0; + + template + void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + cudaStream_t stream{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + constexpr float MAX_IN = 1.0; + std::mt19937 rng(12); + std::uniform_real_distribution dist(0.0, MAX_IN); + + float a_scale = GetScale(MAX_IN); + float b_scale = GetScale(MAX_IN); + float d_scale = GetScale(MAX_IN * MAX_IN * k); + float bias_scale = GetScale(MAX_IN); + + std::vector adata(m * k); + std::generate(adata.begin(), adata.end(), + [&rng, &dist, a_scale] { return static_cast(dist(rng) * a_scale); }); + std::vector bdata(k * n); + std::generate(bdata.begin(), bdata.end(), + [&rng, &dist, b_scale] { return static_cast(dist(rng) * b_scale); }); + std::vector biasdata(m * n); + std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] { + return static_cast(dist(rng) * bias_scale); + }); + + auto ga = transa ? MakeFromData(adata, 0, 0, k, m, k, a_scale) + : MakeFromData(adata, 0, 0, m, k, m, a_scale); + auto gb = transb ? MakeFromData(bdata, 0, 0, n, k, n, b_scale) + : MakeFromData(bdata, 0, 0, k, n, k, b_scale); + auto gbias = MakeFromData(biasdata, 0, 0, m, n, m, bias_scale); + auto gd = Make(m, n, d_scale); + auto gaux = Make(m, n, d_scale); + + auto dims = DistributeTensors(m, n, k); + auto a = transa ? MakeFromData(adata, dims.a_rows_start, dims.a_cols_start, + dims.a_rows_num, dims.a_cols_num, k, a_scale) + : MakeFromData(adata, dims.a_cols_start, dims.a_rows_start, + dims.a_cols_num, dims.a_rows_num, m, a_scale); + auto b = transb ? 
MakeFromData(bdata, dims.b_cols_start, dims.b_rows_start, + dims.b_cols_num, dims.b_rows_num, n, b_scale) + : MakeFromData(bdata, dims.b_rows_start, dims.b_cols_start, + dims.b_rows_num, dims.b_cols_num, k, b_scale); + auto bias = MakeFromData(biasdata, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m, bias_scale); + auto d = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + auto aux = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + + bool grad = false; + bool accumulate = false; + CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad, + accumulate, 0 /*comm_sm_count*/, stream); + auto workspace = Make(1, 32 << 20, 1.0); + nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb, + grad, workspace.data(), accumulate, false /* use_split_accumulator */, + 0 /* math_sm_count */, stream); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + std::vector out(dims.d_rows_num * dims.d_cols_num); + NVTE_CHECK_CUDA( + cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault)); + std::vector out_golden_global(m * n); + NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(), + out_golden_global.size() * sizeof out_golden_global[0], + cudaMemcpyDefault)); + + auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m); + NVTE_CHECK(out.size() == out_golden.size()); + for (size_t i = 0; i < out.size(); ++i) { + EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol * k); + } + } + + NVTECommGemmCtx* ctx_{}; + int nranks_{}; + int rank_{}; + ncclComm_t comm_{}; +}; + +struct AgGemm : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m); + auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t a_cols_start{}; + int64_t b_cols_start{}; + MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = 0, + .a_rows_num = k, + .a_cols_start = a_cols_start, + .a_cols_num = a_cols_num, + .b_rows_start = 0, + .b_rows_num = k, + .b_cols_start = b_cols_start, + .b_cols_num = b_cols_num, + .d_rows_start = a_cols_start, + .d_rows_num = a_cols_num, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmRs : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t rows_start{}; + int64_t d_cols_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + 
.b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = d_cols_start, + .d_cols_num = d_cols_num, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmAr : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + + int64_t rows_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + .b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } + + void SetUp() override { + if (!IsMulticastSupported(rank_)) + GTEST_SKIP() << "Multicast is not supported on device " << rank_; + } +}; + +TEST_P(AgGemm, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmRs, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmAr, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +std::string ParamSuffix(const testing::TestParamInfo& info) { + const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param; + std::ostringstream ss; + ss << static_cast(a_type) << "_" << static_cast(b_type) << "_" + << static_cast(d_type) << "_" << (transa ? "T" : "N") << (transb ? 
"T" : "N") << "_" << m + << "x" << n << "x" << k; + return ss.str(); +} + +INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P( + GemmAr, GemmAr, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}), + &ParamSuffix); diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh new file mode 100644 index 0000000000..3e0852f393 --- /dev/null +++ b/tests/jax/multi_process_launch.sh @@ -0,0 +1,23 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +#!/bin/bash + +SCRIPT_NAME="${SCRIPT_NAME:-test.py}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_command_buffer=''" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader) +for ((i=1; i /dev/null 2>&1 & +done + +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC + +wait diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index ec530a3959..87dfc113c7 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -41,6 +41,7 @@ from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, + get_device_compute_capability, ) from distributed_test_base import assert_equal_collectives @@ -348,6 +349,14 @@ def _check_configs(self): "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) + if ( + get_device_compute_capability(0) == 100 + and self.dropout_prob == 0.1 + and self.attn_bias_type is not AttnBiasType.NO_BIAS + ): + pytest.skip( + "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 9b67de6dd8..e4511e1fe0 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -14,10 +14,11 @@ from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax.quantize import ( - QuantizeConfig, + get_quantize_config, is_fp8_available, ScalingMode, update_collections, + TensorSource, ) from transformer_engine.jax.sharding import MeshResource, global_mesh_resource @@ -49,7 +50,7 @@ def test_update_collections(self): class TestFP8Functions(unittest.TestCase): def _check_default_state(self): - self.assertFalse(QuantizeConfig.is_fp8_enabled()) + self.assertFalse(get_quantize_config().is_fp8_enabled()) def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.margin == test.margin) @@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) - self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) + self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + for tensor_source in TensorSource: + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), + ScalingMode.CURRENT_TENSOR_SCALING, + ) def _compare_mxfp8_scaling(self, test): - self.assertEqual(QuantizeConfig.MARGIN, test.margin) - self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) - self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING) + self.assertEqual(get_quantize_config().MARGIN, test.margin) + self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + for tensor_source in TensorSource: + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING + ) @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_delayed_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. 
self._check_default_state() with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): @@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self): ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_current_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() with fp8_autocast( @@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self): cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) def test_fp8_autocast_mxfp8_block_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. 
self._check_default_state() with fp8_autocast( @@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self): bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 8fe7ebae3d..6f672ade7b 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -23,12 +23,14 @@ from transformer_engine.common import recipe from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.quantize import ( - QuantizeConfig, + get_quantize_config, ScalingMode, is_fp8_available, update_collections, + TensorSource, + fp8_autocast, ) -from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.sharding import MeshResource @pytest.fixture(autouse=True, scope="function") @@ -356,7 +358,7 @@ def test_backward( ref_params, test_params = self._sync_params(ref_params, test_params) - if QuantizeConfig.is_fp8_enabled(): + if get_quantize_config().is_fp8_enabled(): for _ in range(4): _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( inputs, @@ -365,12 +367,15 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: + if ( + get_quantize_config().get_scaling_mode(TensorSource.X) + == ScalingMode.DELAYED_TENSOR_SCALING + ): _, updated_quantize_meta = flax.core.pop( - updated_state[0], QuantizeConfig.COLLECTION_NAME + updated_state[0], get_quantize_config().COLLECTION_NAME ) test_others = update_collections( - {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others + {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others ) del updated_quantize_meta del updated_state @@ -500,41 +505,33 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" - QuantizeConfig.finalize() # Ensure FP8 disabled. - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Ensure FP8 disabled. + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" - QuantizeConfig.finalize() # Ensure FP8 disabled. - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Ensure FP8 disabled. 
+ # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" - QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) - QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" - QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) - QuantizeConfig.finalize() class TestEncoderLayer(BaseTester): diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py new file mode 100644 index 0000000000..6fce62d8cc --- /dev/null +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
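# Illustrative launch sketch (assumed workflow, inferred from the __main__ block at the
# end of this file): the test is run as a standalone script, one process per device, e.g.
#   python test_multi_process_distributed_grouped_gemm.py 127.0.0.1:12345 0 2
#   python test_multi_process_distributed_grouped_gemm.py 127.0.0.1:12345 1 2
# where the arguments are <coordinator_address> <process_id> <num_processes> and every
# process passes the same coordinator address and process count.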
+ +from functools import partial + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.dense import grouped_dense as te_grouped_dense +from transformer_engine.jax.quantize import ( + QuantizerFactory, + ScalingMode, +) + +from utils import assert_allclose + + +N_GROUP = 8 +MESH_AXIS_NAME = "fsdp" + + +def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): + assert kernel_fsdp_axis in [1, 2] + x_shape, w_shape = data_shapes + + x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None)) + w_sharding = ( + NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME)) + if kernel_fsdp_axis == 2 + else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) + ) + w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) + + def init_data(): + x_key = jax.random.PRNGKey(0) + w_key = jax.random.PRNGKey(1) + x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) + w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) + w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim)) + return x, w, w, w_amax + + def test_func(outter_x, outter_w, outter_w_amax): + in_specs = (x_sharding.spec, w_sharding.spec, None) + out_specs = x_sharding.spec + + @partial( + shard_map.shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + def sharded_group_gemm(x, w, w_amax): + group_size = x.shape[0] + x_reshaped = x.reshape(-1, x.shape[-1]) + n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, + n_groups=group_size, + ) + + output = te_grouped_dense( + x_reshaped, + w, + n_groups, + kernel_amax=w_amax, + quantizer_set=quantizer_set, + kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), + ) + output = output.reshape(*x.shape[:-1], -1) + return output + + def run(x, w, w_amax): + output = sharded_group_gemm(x, w, w_amax) + return output + + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax) + dx, dw, _ = vjp_fn(output) + return output, dx, dw + + def ref_func(outter_x, outter_w): + + in_specs = (x_sharding.spec, w_no_sharding.spec) + out_specs = x_sharding.spec + + @partial( + shard_map.shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + def sharded_group_gemm(x, w): + group_size = x.shape[0] + x_reshaped = x.reshape(-1, x.shape[-1]) + n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, + n_groups=group_size, + ) + output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) + output = output.reshape(*x.shape[:-1], -1) + return output + + def run(x, w): + output = sharded_group_gemm(x, w) + return output + + output, vjp_fn = jax.vjp(run, outter_x, outter_w) + dx, dw = vjp_fn(output) + return output, dx, dw + + init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None)) + x, w, w_global, w_amax = init_func() + + o_sharding = x_sharding + test_func_jitted = jax.jit( + test_func, + in_shardings=(x_sharding, w_sharding, None), + out_shardings=(o_sharding, x_sharding, w_sharding), + ) + ref_func_jitted = jax.jit( + ref_func, + in_shardings=(x_sharding, w_no_sharding), + out_shardings=(o_sharding, 
x_sharding, w_no_sharding), + ) + + out, dx, dw = test_func_jitted(x, w, w_amax) + ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) + + assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2) + assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2) + + +if __name__ == "__main__": + from jax.sharding import NamedSharding, PartitionSpec + from jax.experimental import shard_map + import sys + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id + ) + + mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,)) + + with mesh: + data_shapes = [((4, 16, 128, 7168), (7168, 2048))] + for data_shape in data_shapes: + for kernel_fsdp_axis in [1, 2]: + test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 2fc4537f05..2a6e55b2c0 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -12,6 +12,8 @@ import warnings import pprint import yaml +from contextlib import nullcontext +from functools import partial import torch import torch.distributed as dist @@ -35,9 +37,10 @@ def __init__(self, module, num_layers, *args, **kwargs): self.num_layers = num_layers self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)]) - def forward(self, x): - for layer in self.layers: - x = layer(x) + def forward(self, x, layer_contexts): + for layer, context in zip(self.layers, layer_contexts): + with context(): + x = layer(x) return x @@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None): default=False, help="Print out additional debug information.", ) + parser.add_argument( + "--first-last-layers-bf16", + action="store_true", + default=False, + help="Use bf16 for first and last N layers.", + ) + parser.add_argument( + "--num-layers-at-start-in-bf16", + type=int, + default=0, + help="Number of layers at the start to run in bf16.", + ) + parser.add_argument( + "--num-layers-at-end-in-bf16", + type=int, + default=0, + help="Number of layers at the end to run in bf16.", + ) args = parser.parse_args(argv, namespace) if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!") args.use_cuda_graphs = False + if not args.first_last_layers_bf16 and ( + args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0 + ): + warnings.warn( + "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when" + " first-last-layers-bf16 is enabled!" + ) + args.num_layers_at_start_in_bf16 = 0 + args.num_layers_at_end_in_bf16 = 0 + + if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers: + raise ValueError( + "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to" + " num-layers!" 
+ ) + return args @@ -381,10 +418,21 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): "qkv_dgrad": {"method": "ring_exchange"}, "fc1_dgrad": {"method": "ring_exchange"}, } + + quantization_modes = [ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) + ] + if opts.first_last_layers_bf16 and opts.fp8: + quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE) + te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], opts.tp, - use_fp8=opts.fp8, + quantization_modes=quantization_modes, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, @@ -423,6 +471,16 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): elif opts.quantization == "mxfp8": fp8_recipe = MXFP8BlockScaling() + layer_contexts = [ + ( + partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world) + if opts.num_layers_at_start_in_bf16 <= i + and i < (opts.num_layers - opts.num_layers_at_end_in_bf16) + else nullcontext + ) + for i in range(opts.num_layers) + ] + # Prepare random input tensors test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) test_x.retain_grad() @@ -435,14 +493,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): # Execute fwd/bwd and collect tensors to test def run_fwd_bwd(model, x): with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - if isinstance(y, tuple): - out, *_ = y - else: - out = y - loss = out.sum() - loss.backward() + y = model(x, layer_contexts) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 37f0e86692..d6ddfe27c9 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -506,7 +506,13 @@ def main() -> None: model_config.num_heads * model_config.head_dim, ], torch.distributed.get_world_size(group), - use_fp8=model_config.quantization is not None, + quantization_modes=[ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if model_config.quantization is not None + else te.module.base.UserBufferQuantizationMode.NONE + ) + ], dtype=model_config.dtype, bootstrap_backend=bootstrap_backend, ub_cfgs=userbuffer_configs, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9325f5d1e5..bb07e87d98 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1749,25 +1749,44 @@ def test_constant_scale( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75)) + @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75)) @pytest.mark.parametrize("is_training", (True, False)) - @pytest.mark.parametrize("shape", ((101,), (2, 4, 16))) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) + @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128))) @pytest.mark.parametrize("dtype", _dtypes) def test_dropout( self, *, prob: float, is_training: bool, + quantization: 
Optional[str], shape: Iterable[int], dtype: torch.dtype, device: torch.device = "cuda", ): + # Skip invalid configurations + quantized_input = quantization is not None + maybe_skip_quantization(quantization, dims=shape, device=device) + # Random data - x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5 - x_test = x_ref.clone().requires_grad_() - dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5 - dy_test = dy_ref.clone() + # Note: Shift values to make sure inputs are non-zero + x_ref, x_test = make_reference_and_test_tensors( + shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + test_is_quantized=quantized_input, + ) + with torch.no_grad(): + x_test += 1 + x_ref.copy_(x_test) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) # Apply dropout op = te_ops.Dropout(prob) @@ -1775,17 +1794,20 @@ def test_dropout( op.train() else: op.eval() - y = op(x_test) - y.backward(dy_test) + y_test = op(x_test) + y_test.backward(dy_test) # Check values + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") if is_training: - mask = ((y != 0) / (1 - prob)).to(dtype=dtype) - torch.testing.assert_close(y, x_ref * mask) - torch.testing.assert_close(x_test.grad, dy_ref * mask) + tols = dtype_tols(dtype) + mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype) + torch.testing.assert_close(y_test, x_ref * mask, **tols) + torch.testing.assert_close(dx_test, dy_ref * mask, **tols) else: - torch.testing.assert_close(y, x_ref, rtol=0, atol=0) - torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0) + torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0) # Hypothesis testing for number of zeros # Note: A Bernoulli random variable with probability p has @@ -1797,9 +1819,11 @@ def test_dropout( # p-value is less than 1% and we assume that the dropout # distribution is incorrect. 
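        # Worked example (illustrative numbers): with prob = 0.0625 and a (128, 128)
        # input there are n = 16384 elements; if 1000 of them come out zero, the
        # observed drop rate is ~0.0610 and
        #   z = (0.0610 - 0.0625) / sqrt(0.0625 * 0.9375 / 16384) ~= -0.77,
        # well inside |z| < 2.5758, so the distribution check below passes.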
if is_training: - prob_observed = 1 - torch.count_nonzero(y).item() / y.numel() - z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel()) - assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval" + prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel() + z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel()) + assert ( + abs(z_score) < 2.5758 + ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" class TestFusedOps: diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b76f3d2b21..e720673675 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -122,13 +122,18 @@ def is_fused_attn_available( - config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True + config: ModelConfig, + dtype: torch.dtype, + qkv_layout="bshd_bshd_bshd", + is_training=True, + deterministic=False, ): _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=deterministic, ) return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends @@ -839,7 +844,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] - if not is_fused_attn_available(config, dtype): + if not is_fused_attn_available(config, dtype, deterministic=True): pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -887,7 +892,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] - if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): pytest.skip("No attention backend available.") te_gpt = TransformerLayer( @@ -1000,7 +1007,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] - if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): pytest.skip("No attention backend available.") te_mha = MultiheadAttention( diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index b353333a50..e5368497d5 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -65,6 +65,7 @@ fp8_recipes.append(recipe.MXFP8BlockScaling()) if fp8_available: fp8_recipes.append(recipe.DelayedScaling()) + fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(None) supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] @@ -81,11 +82,11 @@ ], outputs=[PyCustomOpDef.dt_uint8], ) -def trt_fp8_quantize(t, scale): +def trt_fp8_quantize(t, scale_inv): """FP8 quantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() q = te.tensor.float8_tensor.Float8Quantizer( - scale=1 / torch.from_numpy(scale).cuda(), + 
scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, ) @@ -101,11 +102,11 @@ def trt_fp8_quantize(t, scale): ], outputs=[PyCustomOpDef.dt_float], ) -def trt_fp8_dequantize(t, scale): +def trt_fp8_dequantize(t, scale_inv): """FP8 dequantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() q = te.tensor.float8_tensor.Float8Quantizer( - scale=1 / torch.from_numpy(scale).cuda(), + scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, ) @@ -593,7 +594,9 @@ def _test_export_layernorm_linear( fname, inp, model, - atol=1e-3, + # For current scaling we use Float8Quantizer in tests + amax computed by hand, + # which has slightly different numerics than Float8CurrentScalingQuantizer. + atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs, ) @@ -1150,6 +1153,11 @@ def test_trt_integration(fp8_recipe: recipe.Recipe): ffn_hidden_size=128, num_attention_heads=4, ).eval() + + if type(fp8_recipe) == recipe.Float8CurrentScaling: + # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling + model = te.LayerNormMLP(128, 128) + inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 5151aa96e7..ae364f80a9 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -38,7 +38,7 @@ Float8Quantizer, Float8Tensor, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from utils import ModelConfig @@ -911,6 +911,54 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): torch.cuda.synchronize() +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "input_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + ], +) +@pytest.mark.parametrize( + "out_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + ], +) +def test_sanity_fp8gemm_with_quantization(N, datatype, input_quantizer, out_quantizer): + # For MXFP8 and CurrentScaling, below unfused quantization should happen + # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output + offset = 32 + scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype) + scratchpad_fp8 = input_quantizer(scratchpad) + inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N)) + weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N)) + outp_type = torch.float32 + quantized_out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=out_quantizer, + bias=None, + use_split_accumulator=False, + ) + out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=None, + bias=None, + use_split_accumulator=False, + ) + expected_quantized_out = out_quantizer(out) + 
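    # Both paths should agree after dequantization: quantizing inside general_gemm
    # (quantization_params=out_quantizer) versus running the same GEMM to high precision
    # and quantizing its output afterwards exercises the unfused quantize-after-GEMM
    # path described in the comment at the top of this test.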
torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize()) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): """Test the functionality of replace_raw_data""" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 524bd3289c..38f400f659 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -266,8 +266,8 @@ def test(): ) ( use_flash_attention, - use_fused_attention, flash_attention_backend, + use_fused_attention, fused_attention_backend, use_unfused_attention, available_backends, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b51e61929b..cb9f13b899 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -69,6 +69,7 @@ list(APPEND transformer_engine_SOURCES transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu activation/gelu.cu + dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu @@ -110,6 +111,12 @@ list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu comm_gemm_overlap/comm_gemm_overlap.cpp) + +if (NVTE_WITH_CUBLASMP) +list(APPEND transformer_engine_SOURCES + comm_gemm/comm_gemm.cpp) +endif() + add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -123,6 +130,8 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine SYSTEM PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI @@ -141,6 +150,25 @@ if (NVTE_ENABLE_NVSHMEM) target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) endif() +option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) +if (NVTE_WITH_CUBLASMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + find_library(CUBLASMP_LIB + NAMES cublasmp libcublasmp + PATHS ${CUBLASMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + find_library(NVSHMEM_HOST_LIB + NAMES nvshmem_host libnvshmem_host.so.3 + PATHS ${NVSHMEM_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") + message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 834c4fe259..7feb5fda5f 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -294,6 +294,38 @@ def _load_nvrtc(): return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) +@functools.lru_cache(maxsize=None) +def _load_curand(): + """Load cuRAND shared library.""" + # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = 
os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Attempt to locate cuRAND in Python dist-packages + found, handle = _load_nvidia_cuda_library("curand") + if found: + return handle + + # Attempt to locate cuRAND via ldconfig + libs = subprocess.check_output( + f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True + ) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "libcurand" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise + return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + + @functools.lru_cache(maxsize=None) def _load_core_library(): """Load shared library with Transformer Engine C extensions""" @@ -303,6 +335,7 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): _CUDNN_LIB_CTYPES = _load_cudnn() _NVRTC_LIB_CTYPES = _load_nvrtc() + _CURAND_LIB_CTYPES = _load_curand() _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") _TE_LIB_CTYPES = _load_core_library() diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp new file mode 100644 index 0000000000..76f46298db --- /dev/null +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -0,0 +1,519 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/comm_gemm.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +using namespace transformer_engine; + +namespace { + +// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability. +// For now, just silently ignoring the errors, since the only diag available in TE is throwing +// exceptions, but these calls will typically be made from destructors, so cannot throw. + +template +auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + NVTE_CHECK_CUDA(create_fn(&raw, std::forward(args)...)); + return std::unique_ptr(raw, destroy_fn); +} + +using CudaStream = + std::unique_ptr, decltype(&cudaStreamDestroy)>; + +CudaStream CudaStreamCreate() { + return CreateWithCudaCheck(cudaStreamCreate, cudaStreamDestroy); +} + +using CudaEvent = std::unique_ptr, decltype(&cudaEventDestroy)>; + +CudaEvent CudaEventCreate(unsigned flags) { + return CreateWithCudaCheck(cudaEventCreateWithFlags, cudaEventDestroy, flags); +} + +template +auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... 
args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + if constexpr (raw_last) { + NVTE_CHECK_CUBLASMP(create_fn(std::forward(args)..., &raw)); + } else { + NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward(args)...)); + } + return std::unique_ptr(raw, destroy_fn); +} + +using CublasMp = + std::unique_ptr, decltype(&cublasMpDestroy)>; + +CublasMp CublasMpCreate(cudaStream_t stream) { + return CreateWithCublasMpCheck(cublasMpCreate, cublasMpDestroy, stream); +} + +using CublasMpGrid = + std::unique_ptr, decltype(&cublasMpGridDestroy)>; + +CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout, + ncclComm_t comm) { + return CreateWithCublasMpCheck(cublasMpGridCreate, cublasMpGridDestroy, + nprow, npcol, layout, comm); +} + +using CublasMpMatrixDesc = std::unique_ptr, + decltype(&cublasMpMatrixDescriptorDestroy)>; + +CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, + int64_t rsrc, int64_t csrc, int64_t lld, + cudaDataType_t type, cublasMpGrid_t grid) { + return CreateWithCublasMpCheck( + cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc, + lld, type, grid); +} + +using CublasMpMatmulDesc = std::unique_ptr, + decltype(&cublasMpMatmulDescriptorDestroy)>; + +CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) { + return CreateWithCublasMpCheck( + cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type); +} + +} // namespace + +struct NVTECommGemmCtx { + int64_t nranks; + int64_t rank; + ncclComm_t comm; + CudaStream stream; + CudaEvent event; + CublasMp cublas_mp; + CublasMpGrid grid_col_major; + CublasMpGrid grid_row_major; + CublasMpMatrixDesc a_desc; + CublasMpMatrixDesc b_desc; + CublasMpMatrixDesc d_desc; + CublasMpMatmulDesc matmul_desc; + void* workspace; + size_t workspace_size; +}; + +namespace { + +int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) { + // Use non-cyclic layout to maximize opportunity for comm overlap. 
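  // Worked example (illustrative): with nranks = 4 and global_size = 10 the block size
  // is ceil(10 / 4) = 3, so ranks 0..2 own 3 rows/columns each and rank 3 owns the
  // remaining 1, i.e. one contiguous block per rank rather than a block-cyclic layout.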
+ return (global_size + ctx->nranks - 1) / ctx->nranks; +} + +void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, k, block_size(ctx, m), 0, 0, k, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, block_size(ctx, m), k, 0, 0, + block_size(ctx, m), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b0 == k, "Unsupported tensor dimensionin B: expected ", k, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), k, 0, 0, + block_size(ctx, n), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, k, block_size(ctx, n), 0, 0, k, + get_cuda_dtype(b->dtype()), + ctx->grid_row_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0); + *ldd = block_size(ctx, m); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, block_size(ctx, m), block_size(ctx, n), 0, + 0, *ldd, get_cuda_dtype(d->dtype()), + ctx->grid_col_major.get(), ctx->d_desc.get())); +} + +void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, m, block_size(ctx, k), 0, 0, m, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n), + get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + k, n, block_size(ctx, k), block_size(ctx, n), 0, 0, block_size(ctx, k), + get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: 
expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); +} + +void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_ERROR("N transpose flag is not supported for input A"); + } + if (transb) { + NVTE_ERROR("T transpose flag is not supported for input B"); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), n, 0, 0, + block_size(ctx, k), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n * ctx->nranks, m, n, 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); + + const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); +} + +using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t, + const Tensor*, const Tensor*, const Tensor*, bool, bool); + +cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) { + static const std::unordered_map s_map{ + {kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT}, + {kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P}, + {kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST}, + {kNVTECommGemmAlgoAtomicP2P, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P}, + {kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST}, + }; + auto it = s_map.find(algo); + return it != s_map.end() ? it->second : static_cast(algo); +} + +void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECommGemmAlgoType algo, + int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, + const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream) { + for (auto t : {a, b, d}) { + NVTE_CHECK(is_tensor_scaling(t->scaling_mode), + "Unsupported scaling mode: " + std::to_string(t->scaling_mode)); + } + + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F)); + + int64_t ldd{}; + init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb); + + const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t trans_b = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, + sizeof trans_a)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, + sizeof trans_b)); + cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, + sizeof algo_attr)); + + const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; + if (is_fp8_dtype(a->dtype())) { + NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, + &a->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(b->dtype())) { + NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, + &b->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(d->dtype())) { + NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, + &d->scale.dptr, sizeof(void*))); + if (d->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, + &d->amax.dptr, sizeof(void*))); + } + } + + // Might be set to ALLREDUCE before, need to OR with the new flags to set. + cublasMpMatmulEpilogue_t epilogue{}; + size_t size_read{}; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue, &size_read)); + NVTE_CHECK(size_read == sizeof epilogue); + // (bias, gelu, grad) -> epilogue + const std::map, cublasMpMatmulEpilogue_t> flags_to_epilogue{ + {{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS}, + {{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD}, + {{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS}, + {{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB}, + {{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX}, + {{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU}, + }; + if (auto it = + flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false, + pre_act_out ? 
pre_act_out->data.dptr != nullptr : false, grad}); + it != flags_to_epilogue.end()) { + epilogue = static_cast(epilogue | it->second); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); + } + + if (bias && bias->data.dptr) { + cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, + sizeof bias_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, + sizeof bias->data.dptr)); + } + + if (pre_act_out && pre_act_out->data.dptr) { + cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, + &aux_type, sizeof aux_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, + &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, + sizeof ldd)); + if (is_fp8_dtype(pre_act_out->dtype())) { + NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, + &scale_mode, sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, + &pre_act_out->scale.dptr, sizeof(void*))); + if (pre_act_out->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, + &pre_act_out->amax.dptr, sizeof(void*))); + } + } + } + + if (comm_sm_count) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, + &comm_sm_count, sizeof comm_sm_count)); + } + + NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + + size_t wrksp_size_device{}; + size_t wrksp_size_host{}; + + float alpha = 1.0; + float beta = accumulate ? 1.0 : 0.0; + std::tuple args{ctx->cublas_mp.get(), + ctx->matmul_desc.get(), + m, + n, + k, + &alpha, + a->data.dptr, + 1, + 1, + ctx->a_desc.get(), + b->data.dptr, + 1, + 1, + ctx->b_desc.get(), + &beta, + accumulate ? d->data.dptr : nullptr, + 1, + 1, + accumulate ? 
ctx->d_desc.get() : nullptr, + d->data.dptr, + 1, + 1, + ctx->d_desc.get()}; + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul_bufferSize, + std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host}))); + + std::vector workspace_host(wrksp_size_host); + if (ctx->workspace_size < wrksp_size_device) { + nvshmem_free(ctx->workspace); + ctx->workspace = nvshmem_malloc(wrksp_size_device); + ctx->workspace_size = wrksp_size_device; + } + + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul, + std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size, + workspace_host.data(), workspace_host.size()}))); + + NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0)); +} + +} // namespace + +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_comm_gemm_ctx_create); + auto stream = CudaStreamCreate(); + auto event = CudaEventCreate(cudaEventDisableTiming); + auto cublas_mp = CublasMpCreate(stream.get()); + + auto col_major = CublasMpGridCreate(nranks, 1, CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm); + auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm); + + // Pre-creating matrix descriptors here, will be initialized with the actual params later. + auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + + auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F); + + return new NVTECommGemmCtx{ + .nranks = nranks, + .rank = rank, + .comm = comm, + .stream = std::move(stream), + .event = std::move(event), + .cublas_mp = std::move(cublas_mp), + .grid_col_major = std::move(col_major), + .grid_row_major = std::move(row_major), + .a_desc = std::move(a_desc), + .b_desc = std::move(b_desc), + .d_desc = std::move(d_desc), + .matmul_desc = std::move(matmul_desc), + }; +} + +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { + NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); + nvshmemx_sync_all_on_stream(ctx->stream.get()); + delete ctx; +} + +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_all_gather_gemm); + cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_reduce_scatter); + cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, 
main_stream); +} + +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_all_reduce); + cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) { + NVTE_API_CALL(nvte_comm_gemm_numroc); + return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks); +} diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 9ba6688ce8..d90dd3abc1 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -101,10 +101,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl DType::kInt32); } // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0)); /* Defining the launcher order between the communication and GEMM kernels @@ -114,11 +114,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl */ int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); int runtime_version = 0; - cudaRuntimeGetVersion(&runtime_version); + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version)); cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0)); if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { - cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming)); } else { _comm_launch_event = 0; } @@ -129,9 +129,13 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); - if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } - if (_atomic_gemm) cudaFree(_counter.dptr()); + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } for (size_t i = 0; i < _stream_compute.size(); i++) { cudaStreamSynchronize(_stream_compute[i]); @@ -698,7 +702,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } } TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, diff --git 
a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 65da58d5f3..1ce89c512f 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) { } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { - if (comm->free_region > NVTE_MAX_REGIONS) return -1; + if (comm->free_region >= NVTE_MAX_REGIONS) return -1; int hndl = comm->free_region; comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); size_t aligned_size = bytes; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 893644ce6f..17f3cf658e 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2319,6 +2319,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + srcoffset; void *dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset; @@ -2516,8 +2517,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, comm->ub_timeout); - if (!signalonly) + NVTE_CHECK_CUDA(cudaGetLastError()); + if (!signalonly) { kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + NVTE_CHECK_CUDA(cudaGetLastError()); + } if (comm->use_ce) { NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); } @@ -2532,6 +2536,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds reinterpret_cast(0 ? 
// temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -2612,24 +2617,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); producer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); + NVTE_CHECK_CUDA(cudaGetLastError()); } void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) { dim3 block(1); dim3 grid(1); reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -2683,6 +2692,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in reduce_fp8_in_bf16_out_cuda <<>>(inputs, output, scale, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale, @@ -2738,4 +2748,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud dim3 grid(num_blocks); reduce_bf16_cuda<<>>( inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 34d6ff72f4..8077f90be8 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -27,7 +27,7 @@ using ExtAllgatherOp = std::function; using ExtBarrierOp = std::function; -#define NVTE_MAX_REGIONS 16 +#define NVTE_MAX_REGIONS 32 #define NVTE_MAX_SMS 32 #define NVTE_MAX_OPS 32 #define NVTE_MAX_PEERS 8192 diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 4e697979d8..8b7f92aff9 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -26,12 +26,31 @@ __global__ void __launch_bounds__(1) } // namespace +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDA_R_16F; + case DType::kFloat32: + return CUDA_R_32F; + case DType::kBFloat16: + return CUDA_R_16BF; + case DType::kFloat8E4M3: + return CUDA_R_8F_E4M3; + case DType::kFloat8E5M2: + return CUDA_R_8F_E5M2; + default: + NVTE_ERROR("Invalid type"); + } +} + void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -73,6 +92,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) dim3 grid(numBlocks, 1, 1); \ memset_kernel \ <<>>(ptr, value, size_in_bytes); \ + NVTE_CHECK_CUDA(cudaGetLastError()); \ return; \ } @@ -83,7 +103,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, 
cudaStream_t stream if (size_in_bytes > 4096) { // Use cudaMemsetAsync for larger sizes. - cudaMemsetAsync(ptr, value, size_in_bytes, stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream)); return; } diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index aa47f2c3d9..e2a3c52aa2 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -270,6 +270,8 @@ struct QuantizationConfig { }; }; +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t); + template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -382,9 +384,19 @@ struct BitsNumber { template struct TypeInfo { #if FP4_TYPE_SUPPORTED - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #else - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #endif template diff --git a/transformer_engine/common/dropout/dropout.cu b/transformer_engine/common/dropout/dropout.cu new file mode 100644 index 0000000000..bab349161e --- /dev/null +++ b/transformer_engine/common/dropout/dropout.cu @@ -0,0 +1,355 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include + +#include "../common.h" +#include "../utils.cuh" +#include "transformer_engine/dropout.h" + +namespace transformer_engine { +namespace { + +// RNG kernels process chunks of 16 entries +constexpr size_t rng_chunk_size = 16; + +// CUDA block size +constexpr size_t block_size = 128; + +// Vector class to help with vectorized memory accesses +template +union Vector { + using StorageType = typename BytesToType::Type; + StorageType storage; + T entries[kSize]; +}; + +/* Byte-wise less-than comparison + * + * Results are stored in each byte's most-significant bit (MSB). All + * other bits are zero. + */ +__device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) { + // Compare low bits by masking MSBs and subtracting. The resulting + // MSBs are 0 if the low bits of a are less than the low bits of b. + uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F); + + // Bitwise logical op to get answer in MSBs + // Equivalent logic: result = (a == b) ? !result : b + asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result)); + + // Mask out everything except MSBs and return + result &= 0x80808080; + return result; +} + +/* Generate dropout mask with 16 bits. + * + * 1 corresponds to keep and 0 to drop. + * + * Consumes 4 values from cuRAND Philox generator. + */ +__device__ __forceinline__ uint16_t make_16bit_mask(uint64_t chunk_idx, uint64_t rng_seed, + uint64_t rng_offset, + uint32_t bytewise_drop_prob) { + // Generate random bits + curandStatePhilox4_32_10_t state; + curand_init(rng_seed, chunk_idx, rng_offset, &state); + const uint4 rand_bits = curand4(&state); + + // Compute mask + // Note: bytewise_less_than fills MSBs (bits 7, 15, 23, 31). By + // shifting 2 bits after every call, every other bit will be filled. 
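  // (Illustration of the packing below: each bytewise_less_than() call fills bits
  //  31, 23, 15 and 7; shifting the running result right by 2 before OR-ing in the
  //  next call spreads the four results across all 16 odd bit positions, and the
  //  "result |= result >> 17" step then folds them into the low 16 bits, one mask
  //  bit per element of the 16-entry chunk.)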
+ uint32_t result = bytewise_less_than(rand_bits.x, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.y, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.z, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.w, bytewise_drop_prob); + + // Consolidate mask in lowest 16 bits + result |= result >> 17; + + // Flip bits so 0 corresponds to drop + result = ~result; + + return result; +} + +// Dropout forward with FP16/BF16 input and output. +template +__global__ void __launch_bounds__(block_size) + dropout_kernel_fwd_f16(const T *__restrict__ input_ptr, T *__restrict__ output_ptr, + uint8_t *__restrict__ mask_ptr, + const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks, + uint32_t bytewise_drop_prob, float scale) { + static_assert(sizeof(T) == 2); + + // Each thread processes a chunk of 16 entries + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Generate dropout mask + auto local_mask = + make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob); + reinterpret_cast(mask_ptr)[chunk_idx] = local_mask; + + // Read input data + using VectorType = Vector; + VectorType local_data; + local_data = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask +#pragma unroll + for (size_t i = 0; i < rng_chunk_size; i++) { + float val = static_cast(local_data.entries[i]); + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_data.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_data; + } +} + +// Dropout forward with FP8 input and FP16/BF16 output. +template +__global__ void __launch_bounds__(block_size) + dropout_kernel_fwd_fp8(const InputType *__restrict__ input_ptr, + const float *__restrict__ input_scale_inv_ptr, + OutputType *__restrict__ output_ptr, uint8_t *__restrict__ mask_ptr, + const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks, + uint32_t bytewise_drop_prob, float scale) { + static_assert(sizeof(InputType) == 1); + static_assert(sizeof(OutputType) == 2); + const float input_scale_inv = *input_scale_inv_ptr; + + // Each thread processes a chunk of 16 entries + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Generate dropout mask + auto local_mask = + make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob); + reinterpret_cast(mask_ptr)[chunk_idx] = local_mask; + + // Read input data + using InputVectorType = Vector; + InputVectorType local_input; + local_input = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask + using OutputVectorType = Vector; + OutputVectorType local_output; +#pragma unroll + for (size_t i = 0; i < rng_chunk_size; i++) { + float val = static_cast(local_input.entries[i]); + val *= input_scale_inv; + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_output.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_output; + } +} + +// Apply dropout mask and scale. 
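As a host-side reference only (not part of the patch), the intent of the byte-wise comparison and of the 8-bit threshold packing that `dropout_fwd` performs further below (floor(p * 256) replicated into every byte) can be expressed in plain integer code; the helper names are hypothetical:

```cpp
#include <cmath>
#include <cstdint>

// Replicate floor(p * 256) into all four bytes of a 32-bit word, mirroring how
// dropout_fwd builds bytewise_drop_prob from the dropout probability.
inline uint32_t pack_drop_threshold(float dropout_probability) {
  uint32_t t = static_cast<uint32_t>(std::floor(dropout_probability * 256));
  t |= t << 8;
  t |= t << 16;
  return t;
}

// Portable equivalent of what bytewise_less_than is intended to compute:
// for each byte, set that byte's MSB when the byte of `a` is (unsigned) less
// than the byte of `b`; all other bits stay zero.
inline uint32_t bytewise_less_than_reference(uint32_t a, uint32_t b) {
  uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    const uint32_t a_byte = (a >> (8 * i)) & 0xFFu;
    const uint32_t b_byte = (b >> (8 * i)) & 0xFFu;
    if (a_byte < b_byte) result |= 0x80u << (8 * i);
  }
  return result;
}

// With the threshold packed this way, each random byte compares true with
// probability floor(p * 256) / 256; the final bit flip in make_16bit_mask then
// makes a 1 bit mean "keep" and a 0 bit mean "drop".
```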
+template <typename T>
+__global__ void __launch_bounds__(block_size)
+    apply_dropout_mask(const T *__restrict__ input_ptr, const uint8_t *__restrict__ mask_ptr,
+                       T *__restrict__ output_ptr, size_t num_chunks, float scale) {
+  // Each thread processes a chunk of 8 entries.
+  const size_t gid = threadIdx.x + blockIdx.x * block_size;
+  const size_t nthreads = gridDim.x * block_size;
+  constexpr size_t chunk_size = 8;
+  for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
+    // Read dropout mask
+    uint8_t local_mask = mask_ptr[chunk_idx];
+
+    // Read input data
+    using VectorType = Vector<T, chunk_size>;
+    VectorType local_data;
+    local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];
+
+    // Apply dropout based on mask
+#pragma unroll
+    for (size_t i = 0; i < chunk_size; i++) {
+      float val = static_cast<float>(local_data.entries[i]);
+      if ((local_mask & 0x1) == 0) {
+        val = 0;
+      }
+      val *= scale;
+      local_data.entries[i] = static_cast<T>(val);
+      local_mask >>= 1;
+    }
+
+    // Write output data
+    reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
+  }
+}
+
+}  // namespace
+
+void dropout_fwd(const Tensor &input, Tensor &output, Tensor &mask, Tensor &rng_state,
+                 float dropout_probability, cudaStream_t stream) {
+  // Check tensors
+  const size_t numel = input.numel();
+  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
+             "Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
+             "but scaling mode is ", to_string(input.scaling_mode), ".");
+  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
+             "Output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
+             to_string(output.scaling_mode), ".");
+  NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Mask tensor must be plain tensor, ",
+             "but scaling mode is ", to_string(mask.scaling_mode), ".");
+  NVTE_CHECK(rng_state.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
+             "RNG state tensor must be INT64 tensor with two entries, ", "but scaling mode is ",
+             to_string(rng_state.scaling_mode), ".");
+  NVTE_CHECK(output.dtype() == DType::kFloat16 || output.dtype() == DType::kBFloat16,
+             "Output tensor must be FP16/BF16 tensor, but dtype is ", to_string(output.dtype()),
+             ".");
+  NVTE_CHECK(rng_state.dtype() == DType::kInt64,
+             "RNG state tensor must be INT64 tensor with two entries, but dtype is ",
+             to_string(rng_state.dtype()), ".");
+  NVTE_CHECK(numel % 16 == 0,
+             "Input tensor number of elements must be divisible by 16, but shape is ",
+             input.shape(), ".");
+  NVTE_CHECK(numel == output.numel(), "Input tensor (shape=", input.shape(),
+             ") and output tensor (shape=", output.shape(), ") do not match.");
+  NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
+             " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
+  NVTE_CHECK(rng_state.numel() == 2, "RNG state tensor must be INT64 tensor with two entries, ",
+             "but shape is ", rng_state.shape(), ".");
+  NVTE_CHECK(input.data.dptr != nullptr, "Input tensor is missing data.");
+  NVTE_CHECK(output.data.dptr != nullptr, "Output tensor is missing data.");
+  NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");
+  NVTE_CHECK(rng_state.data.dptr != nullptr, "RNG state tensor is missing data.");
+
+  // Convert dropout probability to scale and 8-bit representation
+  NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (",
+             dropout_probability, ").");
+  const float scale = 1 / (1 - dropout_probability);
+  uint32_t bytewise_drop_prob =
static_cast(std::floor(dropout_probability * 256)); + bytewise_drop_prob |= bytewise_drop_prob << 8; + bytewise_drop_prob |= bytewise_drop_prob << 16; + + // CUDA config + const size_t num_chunks = numel / rng_chunk_size; + const size_t num_blocks = DIVUP(num_chunks, block_size); + + // Launch kernel depending on input dtype + if (input.dtype() == DType::kFloat16 || input.dtype() == DType::kBFloat16) { + NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()), + ") and output tensor (dtype=", to_string(output.dtype()), ") do not match."); + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.dtype(), DType, + dropout_kernel_fwd_f16<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(rng_state.data.dptr), num_chunks, bytewise_drop_prob, + scale);); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else if (input.dtype() == DType::kFloat8E4M3 || input.dtype() == DType::kFloat8E5M2) { + NVTE_CHECK(input.scale_inv.dptr != nullptr, "Input tensor scale-inverse is not allocated."); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output.dtype(), OutputType, + dropout_kernel_fwd_fp8<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(rng_state.data.dptr), num_chunks, + bytewise_drop_prob, scale); + + );); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else { + NVTE_ERROR("Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ", + "but dtype is ", to_string(input.dtype()), "."); + } +} + +void dropout_bwd(const Tensor &grad_output, const Tensor &mask, Tensor &grad_input, + float dropout_probability, cudaStream_t stream) { + // Check tensors + const size_t numel = grad_output.numel(); + NVTE_CHECK(grad_output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Grad output tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(grad_output.scaling_mode), "."); + NVTE_CHECK(grad_input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Grad input tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(grad_input.scaling_mode), "."); + NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Mask tensor must be a plain tensor, but scaling mode is ", + to_string(mask.scaling_mode), "."); + NVTE_CHECK(grad_output.dtype() == DType::kFloat16 || grad_output.dtype() == DType::kBFloat16, + "Grad output tensor must be FP16/BF16 tensor, but dtype is ", + to_string(grad_output.dtype()), "."); + NVTE_CHECK(grad_output.dtype() == grad_input.dtype(), + "Grad output tensor (dtype=", to_string(grad_output.dtype()), + ") and grad input tensor (dtype=", to_string(grad_input.dtype()), ") do not match."); + NVTE_CHECK(numel % 16 == 0, + "Grad output tensor number of elements must be divisible by 16, but shape is ", + grad_output.shape(), "."); + NVTE_CHECK(numel == grad_input.numel(), "Grad output tensor (shape=", grad_output.shape(), + ") and grad input tensor (shape=", grad_input.shape(), ") do not match."); + NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel, + " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), "."); + NVTE_CHECK(grad_output.data.dptr != nullptr, "Grad output tensor is missing data."); + NVTE_CHECK(grad_input.data.dptr != nullptr, "Grad input tensor is missing data."); + NVTE_CHECK(mask.data.dptr 
!= nullptr, "Mask tensor is missing data."); + + // Convert dropout probablity to scale + NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (", + dropout_probability, ")."); + const float scale = 1 / (1 - dropout_probability); + + // CUDA config + const size_t num_chunks = numel / 8; + const size_t num_blocks = DIVUP(num_chunks, block_size); + + // Launch kernel + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_output.dtype(), DType, + apply_dropout_mask<<>>( + reinterpret_cast(grad_output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(grad_input.data.dptr), num_chunks, scale);); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine + +void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask, + NVTETensor rng_state, float dropout_probability, cudaStream_t stream) { + NVTE_API_CALL(nvte_dropout_fwd); + using namespace transformer_engine; + dropout_fwd(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(mask), *convertNVTETensorCheck(rng_state), + dropout_probability, stream); +} + +void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input, + float dropout_probability, cudaStream_t stream) { + NVTE_API_CALL(nvte_dropout_bwd); + using namespace transformer_engine; + dropout_bwd(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(mask), + *convertNVTETensorCheck(grad_input), dropout_probability, stream); +} diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 15708d2d59..5921d97d52 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor thd_read_half_tensor_kernel<<>>( half.data.dptr, tensor.data.dptr, reinterpret_cast(cu_seqlens.data.dptr), batch, hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]); + NVTE_CHECK_CUDA(cudaGetLastError()); } /*************************************************************************************************** @@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step, reinterpret_cast(lse.data.dptr), reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_lse_kernel<<>>( reinterpret_cast(lse.data.dptr), reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso reinterpret_cast(lse.data.dptr), reinterpret_cast(half_lse.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_lse_kernel<<>>( reinterpret_cast(lse.data.dptr), reinterpret_cast(half_lse.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen, 
lse_per_step_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_out_correction_kernel <<>>( @@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen, lse_per_step_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step, reinterpret_cast(grad.data.dptr), reinterpret_cast(grad_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, hidden_size, total_tokens); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to thd_partition_indices_kernel<<>>( reinterpret_cast(output.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, total_tokens, world_size, rank); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace context_parallel diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 0c261d0fae..59207d59a5 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) { prepare_kernel_fwd<<>>( reinterpret_cast(qkvi.data.dptr), reinterpret_cast(qkv.data.dptr), shape[1], shape[2], shape[3], shape[4]);); + NVTE_CHECK_CUDA(cudaGetLastError()); } void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) { @@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), reinterpret_cast(qkv.data.dptr), q_shape[0], q_shape[1], q_shape[2], q_shape[3]);); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace flash_attention diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 0932b2cf85..4e6c3c858b 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); if (is_ragged_q) { variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; @@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), 
static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); if (is_ragged_q) { variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3e38a5066e..d7f0983763 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in cu_seqlens_to_offsets<<>>( b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + NVTE_CHECK_CUDA(cudaGetLastError()); void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); @@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl( cu_seqlens_to_offsets<<>>( b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + NVTE_CHECK_CUDA(cudaGetLastError()); void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); @@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1( b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1( b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } diff --git a/transformer_engine/common/fused_attn/kv_cache.cu b/transformer_engine/common/fused_attn/kv_cache.cu index 9bdc41e9e2..67119c323b 100644 --- a/transformer_engine/common/fused_attn/kv_cache.cu +++ b/transformer_engine/common/fused_attn/kv_cache.cu @@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso reinterpret_cast(page_table.data.dptr), reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } dim3 grid_size(b, max_ctx_len); copy_to_kv_cache_kernel<<>>( @@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se reinterpret_cast(tensor.data.dptr), reinterpret_cast(new_tensor.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), b, max_seq_len, h, d); + NVTE_CHECK_CUDA(cudaGetLastError()); } void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b, @@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se reinterpret_cast(tensor.data.dptr), reinterpret_cast(new_tensor.data.dptr), 
reinterpret_cast(cu_seqlens.data.dptr), b, max_seq_len, h, d); + NVTE_CHECK_CUDA(cudaGetLastError()); } void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 768dbd99f9..df1eae0dd7 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud // workspace size requires 4 bytes uint32_t *dout = static_cast(workspace); uint32_t hout{}; - cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream)); constexpr int threads = 128; const int blocks = (len - 1) / threads + 1; get_runtime_num_segments_kernel<<>>(static_cast(cu_seqlen), len, dout); - cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); return hout; } @@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>( rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index f64b75d971..a738be8736 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -177,9 +177,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.stream = stream; // Update the max cluster size based on the device - cudaOccupancyMaxPotentialClusterSize( + NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize( &cluster_size, - reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config); + reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config)); cudaLaunchAttribute attribute[1]; attribute[0].id = cudaLaunchAttributeClusterDimension; @@ -189,14 +189,15 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.numAttrs = 1; config.attrs = attribute; - cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel, probs, - tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, - coeff, aux_loss, Const_buf); + NVTE_CHECK_CUDA(cudaLaunchKernelEx( + &config, fused_moe_aux_loss_forward_kernel, probs, tokens_per_expert, + total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf)); } else { size_t smem_size = sizeof(CompType) * num_cols; fused_moe_aux_loss_forward_kernel <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -247,6 +248,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf, int grid_size = (num_rows + block_size - 1) / block_size; fused_moe_aux_loss_backward_kernel<<>>( Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert, diff --git 
a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 47d2150571..03d22942b5 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -151,6 +151,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, @@ -286,6 +287,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index a1785c6639..03e972332a 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -257,6 +257,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( <<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, @@ -447,6 +448,7 @@ void fused_topk_with_score_function_backward_kernel_launcher( <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_backward(const Tensor &routing_map, diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index 1f54f2e720..bbe722a8f5 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -353,6 +353,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_forward( scaled_aligned_causal_masked_softmax_warp_forward <<>>(dst, src, scale, microbatches, query_seq_len, key_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -363,6 +364,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_backward( scaled_aligned_causal_masked_softmax_warp_backward <<>>(gradInput, grad, output, scale, microbatches, query_seq_len, key_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } template diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 02f3153727..79318cd28b 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -513,6 +513,7 @@ void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const in default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -625,6 +626,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, c default: break; } + 
NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -736,6 +738,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 351f4946cf..03cdd68279 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -445,6 +445,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const in default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -561,6 +562,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input, default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index d65cd7b556..9e6c5417bc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -22,24 +22,6 @@ namespace { -cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kFloat16: - return CUDA_R_16F; - case DType::kFloat32: - return CUDA_R_32F; - case DType::kBFloat16: - return CUDA_R_16BF; - case DType::kFloat8E4M3: - return CUDA_R_8F_E4M3; - case DType::kFloat8E5M2: - return CUDA_R_8F_E5M2; - default: - NVTE_ERROR("Invalid type"); - } -} - uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h new file mode 100644 index 0000000000..14cf56a002 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_gemm.h + * \brief Functions for distributed (multi-GPU) matrix multiplication. + * + * This API is a TE-native binding to cuBLASMp library. + * Refer here: https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for specific + * patterns, which allow communication-computation overlap. + * + * All GEMM functions here have the same computation semantic, as expressed + * on global matrices, similar to nvte_cublas_gemm call: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * Functions differ in matrix distribution patterns + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +typedef struct NVTECommGemmCtx NVTECommGemmCtx; + +enum NVTECommGemmAlgoType { + kNVTECommGemmAlgoDefault = 0, + kNVTECommGemmAlgoSplitP2P = 1, + kNVTECommGemmAlgoSplitMulticast = 2, + kNVTECommGemmAlgoAtomicP2P = 3, + kNVTECommGemmAlgoAtomicMulticast = 4 +}; + +/*! \brief Create a comm-gemm context. 
+ * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a comm-gemm context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); + +/*! \brief Perform AllGather communication followed by GEMM + * + * Gathers distributed data from all ranks, then computes matrix multiplication. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by ReduceScatter communication + * + * Computes matrix multiplication, then distributes results across ranks with reduction. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by AllReduce communication + * + * Computes matrix multiplication, then reduces results across all ranks. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. 
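A minimal caller-side sketch of the comm-GEMM API declared in this header (context creation, a local-size query, one AllGather+GEMM call, teardown). This is illustrative only: the NCCL communicator, problem sizes, stream, and the NVTETensor handles describing the local shards are assumed to be set up elsewhere, and error handling is omitted:

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>

void all_gather_gemm_example(ncclComm_t comm, int nranks, int rank, int64_t m, int64_t n,
                             int64_t k, NVTETensor a, NVTETensor b, NVTETensor d,
                             NVTETensor bias, NVTETensor pre_act_out, cudaStream_t stream) {
  NVTECommGemmCtx* ctx = nvte_comm_gemm_ctx_create(comm, nranks, rank);

  // Local row/column counts for the distributed matrices can be derived with numroc.
  int64_t local_n = nvte_comm_gemm_numroc(ctx, n);
  (void)local_n;  // would be used when allocating or validating the local shards

  // AllGather the distributed operand, then run the GEMM: D = AB, AB + bias, or
  // GELU(AB + bias), depending on which of bias / pre_act_out are non-empty.
  nvte_all_gather_gemm(ctx, m, n, k, a, b, d, bias, pre_act_out,
                       /*transa=*/false, /*transb=*/false, /*grad=*/false,
                       /*accumulate=*/false, /*comm_sm_count=*/0, stream,
                       kNVTECommGemmAlgoDefault);

  nvte_comm_gemm_ctx_destroy(ctx);
}
```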
+ * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Get local number of rows or columns. + * + * Utility function to get local dimension. + * Block size, nranks and local rank is derived from the context ctx. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] global_size Global dimension. + */ +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMM_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/dropout.h b/transformer_engine/common/include/transformer_engine/dropout.h new file mode 100644 index 0000000000..6ba1ab9126 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/dropout.h @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dropout.h + * \brief Functions for dropout. + */ + +#ifndef TRANSFORMER_ENGINE_DROPOUT_FP8_H_ +#define TRANSFORMER_ENGINE_DROPOUT_FP8_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Dropout forward kernel. + * + * \param[in] input Input tensor. + * \param[out] output Output tensor. + * \param[out] mask Mask tensor. Each bit corresponds to an + * output tensor entry. Ones indicate kept + * entries and zeros indicate dropped entries. + * \param[in] rng_state RNG engine inputs. + * \param[in] dropout_probability Dropout probability. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask, + NVTETensor rng_state, float dropout_probability, cudaStream_t stream); + +/*! \brief Dropout backward kernel. + * + * \param[in] grad_output Gradient of output tensor. + * \param[out] mask Mask tensor. Each bit corresponds to an + * output tensor entry. Ones indicate kept + * entries and zeros indicate dropped entries. + * \param[out] grad_input Gradient of input tensor. + * \param[in] dropout_probability Dropout probability. + * \param[in] stream CUDA stream used for this operation. 
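The dropout C API declared in this header can be driven in the same style; a minimal sketch, with all tensor creation assumed to happen elsewhere (the mask needs numel/8 bytes and rng_state two INT64 entries, per the checks in dropout.cu):

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/dropout.h>

void dropout_roundtrip_example(NVTETensor input, NVTETensor output, NVTETensor mask,
                               NVTETensor rng_state, NVTETensor grad_output,
                               NVTETensor grad_input, float p, cudaStream_t stream) {
  // Forward: writes dropped-and-scaled activations plus the packed keep/drop mask.
  nvte_dropout_fwd(input, output, mask, rng_state, p, stream);

  // Backward: re-applies the saved mask and the 1/(1-p) scale to the incoming gradient.
  nvte_dropout_bwd(grad_output, mask, grad_input, p, stream);
}
```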
+ */ +void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input, + float dropout_probability, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index ca2fce27ae..cc66562af5 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -413,6 +413,7 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, reinterpret_cast(ret.data.dptr), per_tensor ? reinterpret_cast(ret_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); } void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, @@ -440,6 +441,7 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, reinterpret_cast(ret.data.dptr), per_tensor ? reinterpret_cast(ret_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace multi_tensor_l2norm diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index c280c1c353..337b165080 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -138,8 +138,8 @@ void TeNormalizationPlan::_set_workspace() { if (_launch_params.barrier_bytes > 0) { _launch_params.params.barrier = reinterpret_cast(workspace_dptr + _launch_params.workspace_bytes); - cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes, - _launch_params.stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0, + _launch_params.barrier_bytes, _launch_params.stream)); } if constexpr (std::is_same_v) { _launch_params.params.dgamma_part = diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index f63edfb644..1eeb08415b 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -14,16 +14,16 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_bwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_bwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -49,13 +49,14 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 
Kernel_traits::SMEM_BYTES, - stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES, stream)); } using Kernel_traits_f = @@ -66,13 +67,14 @@ void launch_tuned_(LaunchParams &launch_params, auto kernel_f = &ln_bwd_finalize_tuned_kernel; kernel_f<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_bwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; // Instantiate kernel @@ -87,8 +89,8 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -109,10 +111,11 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } // Launch finalization kernel @@ -126,6 +129,7 @@ void launch_general_(LaunchParams &launch_params, dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); kernel_final<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -134,8 +138,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_ln_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 9336abc26c..787c75ef8c 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -13,15 +13,15 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_fwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, 
Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -45,19 +45,21 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES_FWD, stream)); } } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_fwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_general_kernel; @@ -70,8 +72,8 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -91,10 +93,11 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } } @@ -104,8 +107,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_ln_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 0a7b380000..9bd56c4ec9 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -13,8 +13,8 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_bwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using 
Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_bwd_tuned_kernel; @@ -48,6 +48,7 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); @@ -65,13 +66,14 @@ void launch_tuned_(LaunchParams &launch_params, auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel; kernel_f<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_bwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; // Instantiate kernel @@ -110,6 +112,7 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, @@ -127,6 +130,7 @@ void launch_general_(LaunchParams &launch_params, dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); kernel_final<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -135,8 +139,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_rmsnorm_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 25bed95dc5..90b4f13405 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -13,16 +13,16 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_fwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_fwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -46,19 +46,21 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 
block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES_FWD, stream)); } } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_fwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_fwd_general_kernel; @@ -71,8 +73,8 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -92,10 +94,11 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } } @@ -105,8 +108,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_rmsnorm_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu index a18ea6d4a7..d5f6aeecce 100644 --- a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu +++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu @@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t switch (wait_kind) { case WaitKind::KERNEL_WAIT: wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case WaitKind::NVSHMEM_WAIT: nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream); - cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, - CU_STREAM_WRITE_VALUE_DEFAULT); + NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)signal_reset, + CU_STREAM_WRITE_VALUE_DEFAULT)); break; case WaitKind::STREAM_WAIT: - cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value, - CU_STREAM_WAIT_VALUE_GEQ); - cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, - 
CU_STREAM_WRITE_VALUE_DEFAULT); + NVTE_CHECK_CUDA_DRIVER(cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)wait_value, CU_STREAM_WAIT_VALUE_GEQ)); + NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)signal_reset, + CU_STREAM_WRITE_VALUE_DEFAULT)); break; } } diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 5716196fea..d66298b692 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -243,11 +243,13 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); + NVTE_CHECK_CUDA(cudaGetLastError()); blocks = num_rows; threads = std::min(num_cols / kElementsPerAccess, 1024); moe_permute_kernel<<>>( input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_bwd @@ -259,6 +261,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_kernel<<>>( input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_bwd with probs @@ -282,6 +285,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, } else { NVTE_ERROR("topK cannot exceed 128."); } + NVTE_CHECK_CUDA(cudaGetLastError()); } } } @@ -306,11 +310,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f moe_unpermute_kernel<<>>( input, output, row_id_map, nullptr, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_fwd with probs moe_unpermute_kernel<<>>( input, output, row_id_map, prob, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index f8642cfb68..e1657b77a1 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -60,7 +60,7 @@ __launch_bounds__(amax_kernel_threads) __global__ template void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { // Zero out amax so we can update with atomic max - cudaMemsetAsync(amax, 0, sizeof(float), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); // Return immediately if tensor is empty if (N == 0) { diff --git a/transformer_engine/common/recipe/fp8_block_scaling.cu b/transformer_engine/common/recipe/fp8_block_scaling.cu index 759197dc88..42a7b8d696 100644 --- a/transformer_engine/common/recipe/fp8_block_scaling.cu +++ b/transformer_engine/common/recipe/fp8_block_scaling.cu @@ -183,6 +183,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_ reinterpret_cast(amax.data.dptr), amax_stride_h, amax_stride_w, h, w, start_offset, len);) + NVTE_CHECK_CUDA(cudaGetLastError()); } void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, @@ -215,6 +216,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s reinterpret_cast(out.data.dptr), reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, h, w, start_offset, len);))) + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace 
fp8_block_scaling_recipe diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index fcb379a82b..9ec86a37c6 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -387,22 +387,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_kernel <<>>( input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 2: - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_kernel <<>>( input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 1: - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_kernel <<>>( input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); @@ -411,6 +414,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s NVTE_ERROR("Not valid vec_load_size."); break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } if (input->has_columnwise_data()) { int vec_load_size = (num_tiles_m - 1) % 4 + 1; @@ -422,24 +426,27 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_col_scaling_kernel <<>>(input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k, original_M, original_K); break; case 2: - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_col_scaling_kernel <<>>(input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k, original_M, original_K); break; case 1: - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_col_scaling_kernel <<>>(input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, @@ -449,6 +456,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s NVTE_ERROR("Not valid vec_load_size."); break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } // 2D block scaling @@ -489,23 +497,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, if (is_rowwise) { switch (vec_load_size) { case 4: - cudaFuncSetAttribute( + 
NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; case 2: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; case 1: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; @@ -516,23 +524,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, } else { switch (vec_load_size) { case 4: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; case 2: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; case 1: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index a33f3d959a..55654989a7 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -544,11 +544,11 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { // Zero out tensor data if allocated if (t.data.dptr != nullptr) { const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); - cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream)); } // Set amax to 0 if allocated if (t.amax.dptr != nullptr) { - cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream)); } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 723dbb4a95..648070c8d1 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -335,6 +335,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu static_cast(output.scale.dptr), static_cast(output.amax.dptr), static_cast(output.scale_inv.dptr), row_length, num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } } else { NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index ca48a055a7..6329e79ae7 100644 --- 
a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -264,6 +264,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reinterpret_cast(dbias->data.dptr), reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, reduce_dbias_num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } template , - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); cast_transpose_fused_kernel_notaligned <<>>( param, row_length, num_rows, num_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } if constexpr (IS_DBIAS) { @@ -1197,10 +1199,10 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * (THREADS_PER_WARP + 1) * sizeof(Vec); if (full_tile) { - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( dgated_act_cast_transpose_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); dgated_act_cast_transpose_kernel @@ -1213,11 +1215,12 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( dgated_act_cast_transpose_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); dgated_act_cast_transpose_kernel_notaligned <<>>( @@ -1229,6 +1232,7 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); }); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 2be365465b..bf38565686 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -258,6 +258,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_aligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args_aligned.num_tensors = 0; } if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { @@ -271,6 +272,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_unaligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args_unaligned.num_tensors = 0; } @@ -311,6 +313,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_aligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } if (kernel_args_unaligned.num_tensors > 0) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -323,6 +326,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_unaligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 103f45cf1f..9f0acd8071 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -279,6 +279,7 @@ void 
transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); }); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index 7a19c12852..3c51ce3dab 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -416,6 +416,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reinterpret_cast(dbias->data.dptr), reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, reduce_dbias_num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias, @@ -472,17 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor param.workspace = reinterpret_cast(workspace->data.dptr); if (full_tile) { - cudaFuncSetAttribute(transpose_dbias_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(transpose_dbias_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100)); transpose_dbias_kernel <<>>( param, row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { - cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); transpose_dbias_kernel_notaligned <<>>( param, row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 83359eb053..50ff82d85f 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -950,16 +950,17 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols);); // NOLINT(*) - ); // NOLINT(*) + cols); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } template , - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel @@ -1096,13 +1097,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); 
mxfp8_kernel::cast_mxfp8_gated_kernel @@ -1112,13 +1114,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel @@ -1128,6 +1131,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; }); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 9a02d71f2d..1158132e3f 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -894,6 +894,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, reduce_dbias_kernel <<>>( reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -925,6 +926,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream cast_fp8_1D_kernel<<>>( input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -988,6 +990,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); @@ -1124,10 +1127,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, switch (scaling_type) { case ScalingType::ROWWISE: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); cast_mxfp8_2D_kernel @@ -1136,12 +1139,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); cast_mxfp8_2D_kernel @@ -1150,12 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + 
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); cast_mxfp8_2D_kernel @@ -1164,6 +1169,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index a82f113075..e2d8d34f3d 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -329,6 +329,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ); // NOLINT(*) ); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace dequantization diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 173aad52af..941899b28c 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -12,8 +12,13 @@ #include #include +#ifdef NVTE_WITH_CUBLASMP +#include +#endif // NVTE_WITH_CUBLASMP + #include #include +#include #include "../util/string.h" @@ -87,4 +92,16 @@ } \ } while (false) +#ifdef NVTE_WITH_CUBLASMP + +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUBLASMP + #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index ad6cf2a2ee..0d92b243a7 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -248,6 +248,7 @@ void multi_padding(const std::vector input_list, std::vector o const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_padding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args.num_tensors = 0; } @@ -277,6 +278,7 @@ void multi_padding(const std::vector input_list, std::vector o const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_padding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -322,6 +324,7 @@ void multi_unpadding(const std::vector input_list, std::vector const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_unpadding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args.num_tensors = 0; } @@ -349,6 +352,7 @@ void multi_unpadding(const std::vector input_list, std::vector const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_unpadding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 6e4507eef0..0d667a0ece 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -364,6 +364,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -398,6 +399,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -491,6 +493,7 @@ void 
GatedActivationKernelLauncher(const InputType *input, OutputType *output, c break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -602,6 +605,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 089ef75f1c..df89174b2c 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -34,6 +34,7 @@ te_dtype_to_jax_dtype, get_padded_spec, get_cudnn_version, + get_all_device_compute_capability, ) from ..sharding import ( global_mesh_resource, @@ -2745,6 +2746,11 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) + if 100 in get_all_device_compute_capability(): + assert not ( + attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 + ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 22842e4f3e..a27cec001a 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F """ Helper function to manage primitive states by name without modifying environment variables. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. - This helper is used in the QuantizeConfig.initialize() methods. + This helper is used in the get_quantize_config().initialize() methods. Args: enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 95ef428219..be73f708e2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -28,7 +28,7 @@ ScalingMode, Quantizer, GroupedQuantizer, - QuantizeConfig, + get_quantize_config, QuantizerSet, QuantizeLayout, noop_quantizer_set, @@ -754,7 +754,7 @@ def _te_gemm( fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, - use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -1107,7 +1107,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" precision = ( jax.lax.Precision.HIGHEST - if QuantizeConfig.FP8_2X_ACC_FPROP + if get_quantize_config().FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT ) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 94dfaa45a4..3bda37128b 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -193,6 +193,16 @@ def get_min_device_compute_capability(): ) +def get_all_device_compute_capability(): + """ + Returns a list of compute capability of all local devices. 
+ """ + return tuple( + transformer_engine_jax.get_device_compute_capability(local_gpu_id) + for local_gpu_id in range(len(jax.local_devices())) + ) + + def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None): """ Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 0b27557447..198beb55eb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -57,14 +57,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): name = "te_dbias_quantize_ffi" multiple_results = True impl_static_args = ( - 2, 3, 4, 5, 6, 7, 8, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer + 9, + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval inner_primitive = None outer_primitive = None @@ -72,6 +72,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, *, out_dtype, scaling_mode, @@ -95,7 +96,7 @@ def abstract( rowwise_out_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + updated_amax_aval = amax_aval rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode @@ -168,6 +169,7 @@ def lowering( ctx, x, scale, + amax, *, out_dtype, scaling_mode, @@ -181,13 +183,17 @@ def lowering( te_dbias_quantize_p lowering rules """ del out_dtype, scale_dtype, is_outer - x_aval, scale_aval = ctx.avals_in + x_aval, scale_aval, amax_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)( + assert scale_aval.dtype == amax_aval.dtype == jnp.float32 + return ffi.ffi_lowering( + BaseDBiasQuantizePrimitive.name, + operand_output_aliases={2: 4}, # donate amax buffer to updated_amax + )( ctx, x, scale, + amax, scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, @@ -198,6 +204,7 @@ def lowering( def impl( x, scale, + amax, out_dtype, scaling_mode, q_layout, @@ -222,6 +229,7 @@ def impl( ) = BaseDBiasQuantizePrimitive.inner_primitive.bind( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -268,15 +276,15 @@ def batcher( del is_outer check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale = batched_args - x_bdim, scale_bdim = batch_dims - amax_bdim = scale_bdim + x, scale, amax = batched_args + x_bdim, scale_bdim, amax_bdim = batch_dims out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( BaseDBiasQuantizePrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -303,7 +311,7 @@ def infer_sharding_from_operands( del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. 
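The quantize primitive now takes amax as a third operand that the caller zero-initializes and that the FFI lowering donates to the updated_amax output (operand_output_aliases={2: 4}), replacing the stream-side memset. A minimal JAX-level sketch of the same contract, using hypothetical names rather than the TE primitive itself:

    import jax
    import jax.numpy as jnp

    def quantize_with_amax(x, amax):
        # amax arrives zero-initialized; the op only needs to max-reduce into it.
        updated_amax = jnp.maximum(amax, jnp.max(jnp.abs(x)).reshape((1,)))
        x_q = x  # placeholder for the actual FP8 cast
        return x_q, updated_amax

    # Donating the amax buffer mirrors operand_output_aliases={2: 4} in the lowering.
    quantize_jit = jax.jit(quantize_with_amax, donate_argnums=(1,))
    x = jnp.ones((4, 8), jnp.bfloat16)
    _, amax_out = quantize_jit(x, jnp.zeros((1,), jnp.float32))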
x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), @@ -329,10 +337,8 @@ def infer_sharding_from_operands( desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -341,14 +347,14 @@ def infer_sharding_from_operands( scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" - ) colwise_scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*colwise_scale_inv_spec), desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" + ) return ( out_sharding, @@ -375,7 +381,7 @@ def partition( del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), @@ -401,10 +407,8 @@ def partition( desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -432,7 +436,7 @@ def partition( dbias_sharding, ) - def sharded_impl(x, scale): + def sharded_impl(x, scale, amax): ( local_x, local_colwise_x, @@ -443,6 +447,7 @@ def sharded_impl(x, scale): ) = BaseDBiasQuantizePrimitive.impl( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -510,7 +515,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (x_axes, ("…1",)), + (x_axes, ("…1",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), ) @@ -638,6 +643,9 @@ def _quantize_dbias_impl( elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # Make sure amax is init with zero + amax = jnp.zeros((1,), jnp.float32) + # It is faster to use 1x quantization for tensor scaling is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( @@ -659,6 +667,7 @@ def _quantize_dbias_impl( ) = PrimitiveClass.outer_primitive.bind( x, scale, + amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout.value, @@ -931,6 +940,7 @@ def grouped_quantize( x: jnp.ndarray, quantizer: GroupedQuantizer, group_sizes: jnp.ndarray = None, + amax: jnp.ndarray = None, flatten_axis: int = -1, ) -> GroupedScaledTensor1x: """Quantize a tensor in grouped manner. 
@@ -943,6 +953,7 @@ def grouped_quantize( x: Input tensor to quantize quantizer: The quantizer to use for quantization group_sizes: Array of ints containing the size of each group (default: None) + amax: The amax of x; if None, it is auto-generated. (default: None) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) Returns: @@ -985,7 +996,10 @@ def grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + if amax is not None: + row_amax = amax + else: + row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) segment_ids = jnp.repeat( jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 29d0fbfa6a..032ac9eb70 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -285,18 +285,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t out_dtype_bytes = te_dtype_bytes(out_dtype); if (is_tensor_scaling) { - cudaStream_t stream_0 = nvte_get_compute_stream(0); size_t dpitch = tensor_scaling_sinv_aligment; size_t spitch = lhs_sinv_dtype_bytes; size_t width = lhs_sinv_dtype_bytes; size_t height = lhs_sinv_size; cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream_0); + cudaMemcpyDeviceToDevice, stream); spitch = rhs_sinv_dtype_bytes; width = rhs_sinv_dtype_bytes; height = rhs_sinv_size; cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream_0); + cudaMemcpyDeviceToDevice, stream); lhs_sinv_ptr = lhs_scatter_aligned_ptr; rhs_sinv_ptr = rhs_scatter_aligned_ptr; } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 7bea11f916..d17d83ec1e 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -72,9 +72,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type output_trans_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); @@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (is_fp8_dtype(out_dtype)) { if (is_tensor_scaling) { float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); 
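grouped_quantize can now reuse a caller-supplied amax, for example one that has already been reduced across ranks, instead of recomputing it from x. A rough sketch of producing a compatible per-row amax outside the quantizer, assuming group_axis == 0 as in the 2D case:

    import jax.numpy as jnp

    def precompute_row_amax(x, group_axis=0):
        # Reduce the absolute maximum over every axis after the group axis,
        # matching what grouped_quantize computes internally when amax is None.
        reduce_axes = tuple(range(group_axis + 1, x.ndim))
        return jnp.max(jnp.abs(x), axis=reduce_axes)

    x = jnp.ones((16, 32), jnp.bfloat16)
    row_amax = precompute_row_amax(x)
    # casted = tex.grouped_quantize(x, quantizer, group_sizes, amax=row_amax)  # sketch of the call site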
output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), @@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 4a50fe0e5a..65d65e7d4a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -16,13 +16,45 @@ from . import cpp_extensions as tex from .quantize import ( + ScaledTensorFactory, + ScalingMode, + QuantizeLayout, QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + is_fp8_gemm_with_all_layouts_supported, TensorUsage, ) +def _all_gather_kernel(kernel, mesh_axis, axis_idx): + assert mesh_axis is not None + assert 0 < axis_idx < len(kernel.shape) + + # TODO(Ming Hunag): Add a condition branch for with/without shmap. + kernel_shape = kernel.shape + kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :]) + global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx) + global_kernel = global_kernel.reshape(*kernel_whole_shape) + return global_kernel + + +def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): + assert mesh_axis is not None + assert 0 < axis_idx < len(scattered_kernel_shape) + + # TODO(Ming Hunag): Add a condition branch for with/without shmap. + kernel = kernel.reshape( + *scattered_kernel_shape[:axis_idx], + -1, + scattered_kernel_shape[axis_idx], + *scattered_kernel_shape[axis_idx + 1 :], + ) + kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx) + kernel = kernel.reshape(scattered_kernel_shape) + return kernel + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -253,10 +285,12 @@ def grouped_dense( group_sizes: jnp.ndarray, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), bias: jnp.ndarray = None, + kernel_amax: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + kernel_fsdp_info: Tuple[str, int] = (None, -1), ): """ Perform grouped dense (linear) layer transformation with optional quantization. @@ -268,10 +302,15 @@ def grouped_dense( contracting_dims: Tuple of sequences specifying which dimensions to contract (currently only supports ((1,), (1,))) bias: Bias tensor of shape (G, N) + kernel_amax: The amax values of weight matrix of shape (G,) precision: JAX precision for the GEMM operation preferred_element_type: Preferred data type for the output tensor group_offset: 1D array containing offsets for each group (not yet implemented) quantizer_set: Set of quantizers for FP8 quantization of the input and output + kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix + represented in the format (str, int). The first element is the + FSDP mesh axis, and the second element is the dimension along + which the weight is sharded. 
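When kernel_fsdp_info names a mesh axis, the forward pass all-gathers the sharded weight along that axis before the grouped GEMM and the backward pass reduce-scatters the weight gradient back, which is what _all_gather_kernel and _psum_scatter_kernel implement with an extra reshape. A simplified sketch of the collective pair, assuming it runs inside shard_map with the axis name bound:

    import jax

    def gather_fsdp_weight(kernel, mesh_axis, axis_idx):
        # Concatenate the per-rank shards back along the sharded dimension.
        return jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx, tiled=True)

    def scatter_fsdp_wgrad(wgrad, mesh_axis, axis_idx):
        # Sum the gradient across ranks and keep only the local shard.
        return jax.lax.psum_scatter(wgrad, mesh_axis, scatter_dimension=axis_idx, tiled=True)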
Returns: A jnp.ndarray containing the result of the grouped linear operation @@ -282,25 +321,29 @@ def grouped_dense( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10)) def _grouped_dense( x, kernel, group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ): output, _ = _grouped_dense_fwd_rule( x, @@ -308,10 +351,12 @@ def _grouped_dense( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ) return output @@ -322,21 +367,31 @@ def _grouped_dense_fwd_rule( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ): use_bias = bias is not None is_noop_quantizer_set = quantizer_set == noop_quantizer_set + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + if is_noop_quantizer_set: grouped_gemm_x = x grouped_gemm_kernel = kernel ctx_x = x ctx_kernel = kernel flatten_axis_k = None + + if kernel_fsdp_enabled: + kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) else: + original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout + x_contracting_dims, k_contracting_dims = contracting_dims flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -352,10 +407,24 @@ def _grouped_dense_fwd_rule( ) casted_x = tex.grouped_quantize( - x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + x, + quantizer_set.x, + group_sizes, + flatten_axis=flatten_axis_x, ) + + ctx_kernel_usage = TensorUsage.RHS_TRANS + if kernel_fsdp_enabled: + assert quantizer_set.kernel.scaling_mode in [ + ScalingMode.CURRENT_TENSOR_SCALING, + ScalingMode.DELAYED_TENSOR_SCALING, + ] + # Perform `cast` only + ctx_kernel_usage = TensorUsage.LHS + quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE + casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k ) contracting_dims = (x_contracting_dims, k_contracting_dims) @@ -363,9 +432,51 @@ def _grouped_dense_fwd_rule( # rowwise_casted_x.original_shape == (M, K) # colwise_casted_kernel.original_shape == (G, N, K) grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) - grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) - ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage) + + if kernel_fsdp_enabled: + ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape) + global_ctx_kernel_data = _all_gather_kernel( + ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + kernel_shape = global_ctx_kernel_data.shape + + ctx_kernel = ScaledTensorFactory.create_1x( + global_ctx_kernel_data.reshape(-1), + ctx_kernel.scale_inv, + ctx_kernel.scaling_mode, + dq_dtype=ctx_kernel.dq_dtype, + is_colwise=False, + data_layout="N", + flatten_axis=ctx_kernel.flatten_axis, + group_sizes=ctx_kernel.group_sizes, + original_shape=kernel_shape, + 
group_axis=ctx_kernel.group_axis, + ) + + if is_fp8_gemm_with_all_layouts_supported(): + grouped_gemm_kernel = ctx_kernel + else: + grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1) + grouped_gemm_kernel = ScaledTensorFactory.create_1x( + grouped_gemm_kernel_data.reshape(-1), + ctx_kernel.scale_inv, + ctx_kernel.scaling_mode, + dq_dtype=ctx_kernel.dq_dtype, + is_colwise=True, + data_layout="T", + flatten_axis=ctx_kernel.flatten_axis, + group_sizes=ctx_kernel.group_sizes, + original_shape=kernel_shape, + group_axis=ctx_kernel.group_axis, + ) + else: + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) + + # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one. + # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. + quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout output = tex.grouped_gemm( grouped_gemm_x, @@ -393,7 +504,7 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( - contracting_dims, precision, preferred_element_type, group_offset, ctx, grad + contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims @@ -474,11 +585,17 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + if kernel_fsdp_mesh_axis is not None: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + dkernel_amax = None - return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set + return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index dc9d0209b1..c548c54efa 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -32,7 +32,14 @@ jax_scaled_masked_softmax, jax_scaled_upper_triang_masked_softmax, ) -from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode +from ..quantize import ( + QuantizerFactory, + get_quantize_config, + QuantizeMeta, + QuantizeMetaSet, + ScalingMode, + TensorSource, +) PRNGKey = Any Shape = Tuple[int, ...] 
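Because kernel_amax and kernel_fsdp_info are routed through jax.custom_vjp, nondiff_argnums has to list the new static argument and the backward rule has to emit one cotangent per differentiable input, returning None for kernel_amax (dkernel_amax). A stripped-down illustration of that wiring with a hypothetical function, not the TE grouped dense itself:

    from functools import partial
    import jax

    @partial(jax.custom_vjp, nondiff_argnums=(2,))
    def scaled_matmul(x, w, static_flag, w_amax):
        del static_flag, w_amax  # only the real implementation would use these
        return x @ w

    def scaled_matmul_fwd(x, w, static_flag, w_amax):
        del static_flag, w_amax
        return x @ w, (x, w)

    def scaled_matmul_bwd(static_flag, res, g):
        x, w = res
        # One cotangent per differentiable input; w_amax gets None, like dkernel_amax.
        return g @ w.T, x.T @ g, None

    scaled_matmul.defvjp(scaled_matmul_fwd, scaled_matmul_bwd)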
@@ -350,7 +357,7 @@ def generate_quantize_meta(quantizer_name: str): collection_name = ( variable_collection if variable_collection is not None - else QuantizeConfig.COLLECTION_NAME + else get_quantize_config().COLLECTION_NAME ) scale = self.variable( collection_name, @@ -363,14 +370,14 @@ def generate_quantize_meta(quantizer_name: str): collection_name, f"{quantizer_name}{postfix}_amax_history", jnp.zeros, - (QuantizeConfig.AMAX_HISTORY_LEN,), + (get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32, ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance( - fp8_recipe, recipe.DelayedScaling - ): + if get_quantize_config().get_scaling_mode( + TensorSource.X + ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling): x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") @@ -483,7 +490,7 @@ def __call__(self, inputs: Array) -> Array: self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel = kernel.astype(input_dtype) if self.use_bias: @@ -692,7 +699,7 @@ def __call__(self, inputs: Array) -> Array: quantizer_set = self.generate_quantizer_set() fuse_layernorm = ( - QuantizeConfig.is_fp8_enabled() + get_quantize_config().is_fp8_enabled() and not self.return_layernorm_output and self.enable_layernorm ) @@ -743,7 +750,7 @@ def __call__(self, inputs: Array) -> Array: kernel_shape, self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -1005,7 +1012,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: # TODO(Phuong): use fuse_layernorm for high-precision # when NoOpQuantizer and Tensor are implemented fuse_layernorm = ( - QuantizeConfig.is_fp8_enabled() + get_quantize_config().is_fp8_enabled() and not self.return_layernorm_output and self.enable_layernorm ) @@ -1088,7 +1095,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] @@ -1100,7 +1107,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_2_shape, self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8727ea7e34..00e3ddc3e8 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -289,6 +289,13 @@ def _layernorm_mlp_fwd_rule( bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) + # This sharding constraint is needed to correct the Shardy sharding propagation + if dot_2_input_axes is not None: + dot_1_output_axes = ( + dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:] + ) # add the act_num axis + dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) + dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) diff --git a/transformer_engine/jax/quantize/helper.py 
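Call sites that previously read QuantizeConfig class attributes now go through get_quantize_config(), which returns whichever config instance the enclosing fp8_autocast installed. A hedged usage sketch, assuming the accessor is re-exported from transformer_engine.jax.quantize as it is in these modules:

    from transformer_engine.jax.quantize import get_quantize_config

    def fp8_meta_shape():
        # Outside fp8_autocast this resolves to the NoOpQuantizeConfig instance,
        # so is_fp8_enabled() is False and no FP8 metadata needs to be allocated.
        if not get_quantize_config().is_fp8_enabled():
            return None
        # Inside fp8_autocast with DelayedScaling this is the recipe's history length.
        return (get_quantize_config().AMAX_HISTORY_LEN,)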
b/transformer_engine/jax/quantize/helper.py index f8d18983e4..3d460e81ab 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -7,9 +7,11 @@ This module provides configuration and helper functions for managing quantization metadata in JAX, including support for different scaling modes and datatypes. """ +from abc import ABC, abstractmethod from contextlib import contextmanager +from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Dict, Union, Sequence +from typing import Optional, Tuple, Dict, Union, Sequence, Type from functools import reduce import operator @@ -26,7 +28,7 @@ from .device_utils import get_device_compute_capability __all__ = [ - "QuantizeConfig", + "get_quantize_config", "fp8_autocast", "is_fp8_available", "update_collections", @@ -34,12 +36,15 @@ "apply_padding_to_scale_inv", "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", + "TensorSource", ] _is_fp8_available = None _reason_for_no_fp8 = "" Collection = Union[Dict, FrozenDict] +NVTE_FP8_COLLECTION_NAME = "fp8_metas" + def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -154,6 +159,17 @@ def _format2dtypes(format_: recipe.Format): return jnp.bfloat16, jnp.bfloat16 +class TensorSource(Enum): + """Enumeration for where a tensor's data comes from.""" + + # Input data + X = 0 + # Model parameters + KERNEL = 1 + # Gradients in the backward pass + DGRAD = 2 + + class AmaxComputeAlgo(Enum): """Enumeration for AMAX computation algorithms. @@ -166,28 +182,8 @@ class AmaxComputeAlgo(Enum): MOST_RECENT = "most_recent" -def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: - """Convert recipe.Recipe to ScalingMode. - - Args: - fp8_recipe: The FP8 recipe to convert - - Returns: - The corresponding ScalingMode - - Raises: - ValueError: If the recipe type is not supported - """ - if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.DELAYED_TENSOR_SCALING - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.MXFP8_1D_SCALING - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): - return ScalingMode.CURRENT_TENSOR_SCALING - raise ValueError("Invalid fp8_recipe!") - - -class QuantizeConfig: +@dataclass +class BaseQuantizeConfig(ABC): """Configuration class for quantization settings. 
This class manages global quantization settings including FP8 formats, @@ -204,14 +200,13 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients INFERENCE_MODE: Whether to enable optimization for inference - SCALING_MODE: Scaling mode AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_COMPUTE_ALGO: Algorithm for AMAX computation """ INITIALIZED = False MARGIN: float = 0.0 - COLLECTION_NAME: str = "fp8_metas" + COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] @@ -219,61 +214,82 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX - @staticmethod - def is_fp8_enabled(): + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize the quantization configuration. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + """ + self.INITIALIZED = True + self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 + self.FP8_FORMAT = fp8_recipe.fp8_format + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT) + + def is_fp8_enabled(self) -> bool: """Check if FP8 quantization is enabled. Returns: bool: True if quantization is enabled, False otherwise """ - return QuantizeConfig.INITIALIZED + return self.INITIALIZED - @classmethod - def initialize(cls, fp8_recipe: recipe.Recipe) -> None: - """Initialize the quantization configuration. + @abstractmethod + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type. Args: - fp8_recipe: The FP8 recipe to use for initialization + tensor_source: The usage type for which to get the scaling mode. + + Returns: + The scaling mode for the specified usage type. + """ + + def is_supported(self) -> tuple[bool, str]: + """Check if this QuantizeConfig class is supported on the available devices. + + Returns: + bool: True if the class is supported, False otherwise + str: Reason for being unsupported, if applicable. 
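Taking a TensorSource argument in get_scaling_mode leaves room for recipes that treat activations, weights, and gradients differently, and is_supported() then probes hardware support for each of the three modes. A hypothetical subclass, not part of this patch, sketching why the per-source hook is useful (import paths are assumptions):

    from transformer_engine.jax.quantize import ScalingMode, TensorSource  # re-exports assumed
    from transformer_engine.jax.quantize.helper import BaseQuantizeConfig  # module path assumed

    class MixedScalingQuantizeConfig(BaseQuantizeConfig):
        """Hypothetical recipe: delayed scaling for weights, current scaling elsewhere."""

        def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
            if tensor_source == TensorSource.KERNEL:
                return ScalingMode.DELAYED_TENSOR_SCALING
            return ScalingMode.CURRENT_TENSOR_SCALING

The inherited is_supported() would then validate both scaling modes against the local devices before the config is installed.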
""" - cls.INITIALIZED = True - cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 - cls.FP8_FORMAT = fp8_recipe.fp8_format - cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) - - @classmethod - def finalize(cls) -> None: - """Reset the quantization configuration to default values.""" - cls.INITIALIZED = False - cls.MARGIN = 0.0 - cls.FP8_FORMAT = recipe.Format.HYBRID - cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.FP8_2X_ACC_FPROP = False - cls.FP8_2X_ACC_DGRAD = False - cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.INFERENCE_MODE = False - # DelayedScaling - cls.AMAX_HISTORY_LEN = 1024 - cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX - - -class DelayedScalingQuantizeConfig: + + x_scaling_mode = self.get_scaling_mode(TensorSource.X) + kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) + for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: + is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) + if not is_supported: + return is_supported, reason + return True, None + + +class NoOpQuantizeConfig(BaseQuantizeConfig): + """Configuration class higher-precision non-quantized operation.""" + + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize no-op configuration.""" + raise NotImplementedError( + "NoOpQuantizeConfig cannot be initialize from a recipe as it represents" + " higher-precision when no quantized recipe is set." + ) + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.NO_SCALING + + +class DelayedScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for delayed scaling FP8 recipe. This class provides specific initialization and finalization for delayed scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize delayed scaling FP8 configuration. Args: @@ -282,6 +298,8 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: Raises: AssertionError: If recipe parameters are not supported """ + super().initialize_from_recipe(fp8_recipe) + assert fp8_recipe.amax_compute_algo in [ "max", "most_recent", @@ -291,71 +309,88 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." 
- cls = QuantizeConfig - cls.initialize(fp8_recipe) - - cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len + self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len string_to_amax_compute_algo = { "max": AmaxComputeAlgo.MAX, "most_recent": AmaxComputeAlgo.MOST_RECENT, } - cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] + self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] - cls.FP8_2X_ACC_DGRAD = True - cls.FP8_2X_ACC_WGRAD = True + self.FP8_2X_ACC_DGRAD = True + self.FP8_2X_ACC_WGRAD = True - @staticmethod - def finalize() -> None: - """Reset the delayed scaling configuration.""" - QuantizeConfig.finalize() + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.DELAYED_TENSOR_SCALING -class CurrentScalingQuantizeConfig: +class CurrentScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for current scaling FP8 recipe. This class provides specific initialization and finalization for current scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize current scaling FP8 configuration. Args: fp8_recipe: The FP8 recipe to use for initialization """ - cls = QuantizeConfig - cls.initialize(fp8_recipe) - cls.AMAX_HISTORY_LEN = 0 + super().initialize_from_recipe(fp8_recipe) + self.AMAX_HISTORY_LEN = 0 - @staticmethod - def finalize() -> None: - """Reset the current scaling configuration.""" - QuantizeConfig.finalize() + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.CURRENT_TENSOR_SCALING -class BlockScalingQuantizeConfig: +class BlockScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for block scaling FP8 recipe. This class provides specific initialization and finalization for block scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize block scaling FP8 configuration. Args: fp8_recipe: The FP8 recipe to use for initialization """ - cls = QuantizeConfig - cls.initialize(fp8_recipe) - cls.AMAX_HISTORY_LEN = 0 + super().initialize_from_recipe(fp8_recipe) + self.AMAX_HISTORY_LEN = 0 + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.MXFP8_1D_SCALING + + +_QUANTIZE_CONFIG = NoOpQuantizeConfig() + - @staticmethod - def finalize() -> None: - """Reset the block scaling configuration.""" - QuantizeConfig.finalize() +def get_quantize_config(): + """Global instance of BaseQuantizeConfig set by fp8_autocast context.""" + return _QUANTIZE_CONFIG + + +def get_quantize_config_class( + fp8_recipe: recipe.Recipe, +) -> Type[BaseQuantizeConfig]: + """Get the quantization configuration based on the FP8 recipe. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + Returns: + The quantization config class corresponding to the given recipe. 
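get_quantize_config_class only maps a recipe to a config class; fp8_autocast is what instantiates it, checks is_supported(), initializes it from the recipe, and installs it as the global config for the duration of the context. A short standalone sketch of that sequence, with import paths assumed:

    from transformer_engine.common.recipe import Float8CurrentScaling
    from transformer_engine.jax.quantize.helper import get_quantize_config_class

    fp8_recipe = Float8CurrentScaling()
    config_cls = get_quantize_config_class(fp8_recipe)  # CurrentScalingQuantizeConfig
    config = config_cls()
    ok, reason = config.is_supported()
    if ok:
        config.initialize_from_recipe(fp8_recipe)
        assert config.is_fp8_enabled()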
+ """ + if isinstance(fp8_recipe, recipe.DelayedScaling): + return DelayedScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + return BlockScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.Float8CurrentScaling): + return CurrentScalingQuantizeConfig + raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") @contextmanager @@ -404,22 +439,22 @@ def fp8_autocast( if fp8_recipe is None: fp8_recipe = recipe.DelayedScaling() - Config = DelayedScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - Config = BlockScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): - Config = CurrentScalingQuantizeConfig + global _QUANTIZE_CONFIG + + old_quantize_config = _QUANTIZE_CONFIG + + _QUANTIZE_CONFIG = NoOpQuantizeConfig() try: with global_shard_guard(mesh_resource): if enabled: - fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe)) - assert fp8_available, reason_for_no_fp8 - - Config.initialize(fp8_recipe) + _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)() + is_supported, reason = _QUANTIZE_CONFIG.is_supported() + assert is_supported, reason + _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe) yield finally: - Config.finalize() + _QUANTIZE_CONFIG = old_quantize_config def get_delayed_scaling(): @@ -437,12 +472,12 @@ def get_delayed_scaling(): an instance of DelayedScaling which is set via fp8_autocast. """ amax_compute_algo = ( - "max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" + "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" ) return recipe.DelayedScaling( - margin=int(QuantizeConfig.MARGIN), - fp8_format=QuantizeConfig.FP8_FORMAT, - amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN, + margin=int(get_quantize_config().MARGIN), + fp8_format=get_quantize_config().FP8_FORMAT, + amax_history_len=get_quantize_config().AMAX_HISTORY_LEN, amax_compute_algo=amax_compute_algo, ) @@ -581,6 +616,3 @@ def apply_padding_to_scale_inv( # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) - - -NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 9a65f99bf3..6cecfa361f 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -21,9 +21,10 @@ from .scaling_modes import ScalingMode from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .helper import ( - QuantizeConfig, + get_quantize_config, + get_quantize_config_class, AmaxComputeAlgo, - _get_scaling_mode, + TensorSource, ) from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -56,7 +57,7 @@ def compute_scale_from_amax( fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) if scale is None: scale = jnp.ones((1,)) - sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) + sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) return sf @@ -234,7 +235,7 @@ def _quantize_func( dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) amax = jnp.max(jnp.abs(x)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) - scale = (fp8_max / amax) / 
(2**QuantizeConfig.MARGIN) + scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) scaled_x = x.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) @@ -320,7 +321,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) amax_history: jnp.ndarray = field( - default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32) + default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32) ) def tree_flatten(self): @@ -397,7 +398,7 @@ def _compute_scale(amax_history, scale, q_dtype): Updated scale value """ # 2. Calculate the current scale - if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: amax = jnp.max(amax_history, axis=-1, keepdims=True) else: amax = amax_history[0:1] @@ -827,12 +828,21 @@ def create( @staticmethod def _create_set( - scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + x_scaling_mode, + kernel_scaling_mode, + grad_scaling_mode, + fwd_dtype, + bwd_dtype, + is_2x2x, + n_groups, + **kwargs, ) -> QuantizerSet: """Create a set of quantizers for forward and backward passes. Args: - scaling_mode: Scaling mode to use + x_scaling_mode: Scaling mode to use for input tensor 'x' + kernel_scaling_mode: Scaling mode to use for kernel tensor + grad_scaling_mode: Scaling mode to use for gradient tensor fwd_dtype: Data type for forward pass bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization @@ -846,9 +856,9 @@ def _create_set( q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE - if scaling_mode.is_1d_block_scaling(): + if kernel_scaling_mode.is_1d_block_scaling(): q_layout_kernel = QuantizeLayout.COLWISE - if QuantizeConfig.INFERENCE_MODE: + if get_quantize_config().INFERENCE_MODE: q_layout_dgrad = None if "quantize_meta_set" in kwargs: @@ -868,12 +878,12 @@ def _create_set( else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) + q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) q_kernel = QuantizerFactory.create( - 1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel + 1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel ) q_dgrad = QuantizerFactory.create( - 1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad + 1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad ) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @@ -892,10 +902,10 @@ def create_set( Args: n_quantizer_sets: Number of quantizer sets to create - scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE - fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE - bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE - is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X + scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode + fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE + bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE + is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X n_groups: 
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. **kwargs: Additional arguments for quantizer initialization @@ -912,27 +922,44 @@ def create_set( ) if fp8_recipe is not None: - # TODO(jberchtold): once recipe and scaling mode are decoupled update this logic - scaling_mode = _get_scaling_mode(fp8_recipe) + quantize_config = get_quantize_config_class(fp8_recipe)() + x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) + kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) + elif scaling_mode is not None: + x_scaling_mode = scaling_mode + kernel_scaling_mode = scaling_mode + grad_scaling_mode = scaling_mode else: - scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE - fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE - bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE + x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) + kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + + fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE + bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: - if scaling_mode.is_1d_block_scaling(): + # TODO(Jeremy): check x, kernel, grad separately for 2x + if x_scaling_mode.is_1d_block_scaling(): is_2x2x = True - elif scaling_mode.is_tensor_scaling(): + elif x_scaling_mode.is_tensor_scaling(): is_2x2x = not is_fp8_gemm_with_all_layouts_supported() else: # NO_SCALING ignores is_2x2x for now is_2x2x = False - is_inference_mode = QuantizeConfig.INFERENCE_MODE + is_inference_mode = get_quantize_config().INFERENCE_MODE assert not is_inference_mode, "Inference mode is not supported yet!" 
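
As a hedged illustration of the new `fp8_recipe` path through `create_set` (not part of the patch; the import locations are assumptions, and the keyword names follow the docstring above):

```python
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import QuantizerFactory  # assumed re-export

# Preferred: pass a recipe and let the factory derive the x/kernel/dgrad scaling
# modes through the recipe's quantize-config class, as implemented above.
q_sets = QuantizerFactory.create_set(
    n_quantizer_sets=1,
    fp8_recipe=recipe.DelayedScaling(),
)

# Alternative: pass scaling_mode explicitly, which applies the same mode to x,
# kernel, and dgrad; omitting both falls back to the config set by fp8_autocast.
```

Note that `is_2x2x` still defaults from the x scaling mode only; the TODO in the diff flags checking x, kernel, and grad separately in a follow-up.
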
q_set = [] for _ in range(n_quantizer_sets): q_set.append( QuantizerFactory._create_set( - scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + x_scaling_mode=x_scaling_mode, + kernel_scaling_mode=kernel_scaling_mode, + grad_scaling_mode=grad_scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=is_2x2x, + n_groups=n_groups, + **kwargs, ) ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index fc4fd13531..868570f73c 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -396,7 +396,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: The quantize layout for the tensor usage """ # If we need to support 1x1x for inference in the future - # if QuantizeConfig.INFERENCE_MODE: + # if get_quantize_config().INFERENCE_MODE: # assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!") # if usage == TensorUsage.LHS: # return QuantizeLayout.ROWWISE diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index caa2a46206..339e74e2fc 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -41,22 +41,32 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): return mesh.shape[resource], resource -def _validate_mesh_resource_configuration(): +def _validate_mesh_resource_configuration(mesh_resource): """Validate that the mesh resource configuration is consistent and conflict-free.""" - gsr = global_mesh_resource() - - is_dp_enabled = gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1 - is_tp_enabled = gsr.tp_resource is not None and get_mesh_axis_size(gsr.tp_resource) > 1 - is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 - is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1 + is_dp_enabled = ( + mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1 + ) + is_tp_enabled = ( + mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1 + ) + is_tpsp_enabled = ( + mesh_resource.tpsp_resource is not None + and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1 + ) + is_fsdp_enabled = ( + mesh_resource.fsdp_resource is not None + and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1 + ) assert not (is_dp_enabled and is_fsdp_enabled), ( "Data parallelism and full-sharded data parallelism cannot be enabled at the same time." - f" Got dp_resource={gsr.dp_resource} and fsdp_resource={gsr.fsdp_resource}" + f" Got dp_resource={mesh_resource.dp_resource} and" + f" fsdp_resource={mesh_resource.fsdp_resource}" ) assert not (is_tp_enabled and is_tpsp_enabled), ( "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time." 
- f" Got tp_resource={gsr.tp_resource} and tpsp_resource={gsr.tpsp_resource}" + f" Got tp_resource={mesh_resource.tp_resource} and" + f" tpsp_resource={mesh_resource.tpsp_resource}" ) @@ -155,7 +165,7 @@ def with_sharding_constraint_by_logical_axes( flax_rules = flax.linen.get_logical_axis_rules() if len(flax_rules) > 0: return flax.linen.with_logical_constraint( - x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT + x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED ) except ImportError: pass @@ -305,7 +315,6 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource - _validate_mesh_resource_configuration() yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -322,6 +331,7 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 2e86a77a5a..3bdbe4089e 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -33,6 +33,7 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub +from transformer_engine.pytorch.module import UserBufferQuantizationMode from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.attention import InferenceParams diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9d6677b628..7097f4ba0f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -434,8 +434,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12") + if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -822,7 +822,7 @@ def get_attention_backend( # flash-attn >=2.4.1 | yes # FusedAttention | # sub-backend 0 | yes - # sub-backend 1 | workspace optimization path and sm90+: yes; + # sub-backend 1 | workspace optimization path and sm90: yes; # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes @@ -838,8 +838,9 @@ def get_attention_backend( use_flash_attention_2 = False if use_fused_attention and deterministic: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with FP8") use_fused_attention = False + fused_attention_backend = None if ( 
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training @@ -849,8 +850,13 @@ def get_attention_backend( or cudnn_version < (8, 9, 5) ) ): - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") + use_fused_attention = False + fused_attention_backend = None + if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0): + logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") use_fused_attention = False + fused_attention_backend = None # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 3fdf8b14fd..179c80a656 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -551,17 +551,23 @@ def bulk_reload_group(self, group_to_reload): buffer_idx = 0 double_buffer_idx = group_to_reload % 2 + main_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] - else: - reload_buffer = None if isinstance(state, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state[1], device=torch.cuda.current_device() + ) + recovered_tensor = SynchronizedGroupOffloadHandler.reload( state, True, reload_buffer ) @@ -570,14 +576,18 @@ def bulk_reload_group(self, group_to_reload): elif isinstance(state, list): tensor_list = [] for state_tuple in state: - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][ - buffer_idx - ] - else: - reload_buffer = None if isinstance(state_tuple, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][ + buffer_idx + ] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state_tuple[1], device=torch.cuda.current_device() + ) + tensor_list.append( SynchronizedGroupOffloadHandler.reload( state_tuple, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d0e92a59bc..a6b65562eb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -265,6 +265,17 @@ std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Te std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer); +/*************************************************************************************************** + * Dropout + **************************************************************************************************/ + +std::vector dropout_fwd(const py::handle &input, const float dropout_probability, + std::optional out = std::nullopt); + +py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, + const float dropout_probability, + std::optional grad_input = std::nullopt); + /*************************************************************************************************** * Softmax 
**************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 819d3e5185..e9647b44fe 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -205,11 +205,8 @@ std::tuple, std::vector> bulk_allocate_fp auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); - // in the case where full buffer is empty because local rank receives no tokens for all the experts - // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob - // but in the case where some experts receive tokens, some not, we want to leverage from_blob - // as much as possible to avoid CPU overhead - if (buffer->data_ptr() == nullptr) { + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } return at::from_blob( @@ -359,11 +356,8 @@ std::tuple, std::vector> bulk_allocate_mx auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); - // in the case where full buffer is empty because local rank receives no tokens for all the experts - // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob - // but in the case where some experts receive tokens, some not, we want to leverage from_blob - // as much as possible to avoid CPU overhead - if (buffer->data_ptr() == nullptr) { + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } return at::from_blob( diff --git a/transformer_engine/pytorch/csrc/extensions/dropout.cpp b/transformer_engine/pytorch/csrc/extensions/dropout.cpp new file mode 100644 index 0000000000..e6f29d0da7 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/dropout.cpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/dropout.h" + +#include +#include + +#include + +#include "../common.h" +#include "../extensions.h" +#include "../pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace pytorch { + +std::vector dropout_fwd(const py::handle &input, float dropout_probability, + std::optional out) { + using namespace transformer_engine::pytorch::detail; + + // Input tensor + const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none()); + + // Allocate output tensor if needed + if (!out) { + at::ScalarType dtype = GetATenDType(input_nvte.dtype()); + if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) { + dtype = input.attr("dtype").cast(); + } + const auto shape_uint64 = convertShape(input_nvte.shape()); + const std::vector shape_int64(shape_uint64.begin(), shape_uint64.end()); + const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA); + out = at::empty(shape_int64, opts); + } + TensorWrapper out_nvte = makeTransformerEngineTensor(*out); + + // Mask tensor + auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte); + auto mask_nvte = makeTransformerEngineTensor(mask_pyt); + + // RNG state tensor + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args; + { + std::lock_guard lock(gen->mutex_); + constexpr int64_t rng_elts_per_thread = 4; + philox_args = gen->philox_cuda_state(rng_elts_per_thread); + } + auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64); + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset( + reinterpret_cast(rng_state_pyt.data_ptr()), philox_args.captured_, + philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr, + philox_args.offset_.val, philox_args.offset_intragraph_, at::cuda::getCurrentCUDAStream()); + }); + auto rng_state_nvte = makeTransformerEngineTensor(rng_state_pyt); + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE({ + nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(), + dropout_probability, at::cuda::getCurrentCUDAStream()); + }); + + return {py::cast(std::move(*out)), py::cast(mask_pyt)}; +} + +py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, + const float dropout_probability, std::optional grad_input) { + const auto grad_output_nvte = makeTransformerEngineTensor(grad_output); + const auto mask_nvte = makeTransformerEngineTensor(mask); + if (!grad_input) { + grad_input = at::empty_like(grad_output); + } + auto grad_input_nvte = makeTransformerEngineTensor(*grad_input); + NVTE_SCOPED_GIL_RELEASE({ + nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(), + dropout_probability, at::cuda::getCurrentCUDAStream()); + }); + return py::cast(std::move(*grad_input)); +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f4768bb9ba..b9f91c7195 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -93,6 +93,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans bool use_split_accumulator, CommOverlapCore* comm_overlap, std::optional comm_type, MaybeTensor extra_output, bool bulk_overlap, float alpha, std::optional beta) { + using namespace 
transformer_engine::pytorch::detail; + // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -123,10 +125,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans "into D tensor. Beta has nothing to be applied to."); } + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); @@ -139,12 +141,33 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } + // maintain unquantized tensor in case we need unfused quantization support. + TensorWrapper unquantized_D_tensor; + py::object unquantized_out; + // Unfused quantization is needed in the following cases + // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) + // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, + // GEMM Output needs to be in BF16, to allow for unfused quantization) + bool unfused_quantization_needed; + if (low_precision) { + unfused_quantization_needed = !quantizer.is_none() && !IsFloat8Quantizers(quantizer.ptr()); + } else { + unfused_quantization_needed = !quantizer.is_none(); + } + + if (unfused_quantization_needed) { + NoneQuantizer q{none}; + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + } + TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; + // Bias tensor TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + auto opts = + torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -157,7 +180,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : D_tensor.dtype(); + DType gelu_type = low_precision ? 
bias_type : out_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -210,7 +233,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -218,14 +241,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -234,14 +257,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -251,15 +274,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(main_stream); + if (out_tensor.numel() != 0 && !accumulate) { + out_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -267,7 +290,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + if (unfused_quantization_needed) my_quantizer->quantize(unquantized_D_tensor, D_tensor); // Pack outputs std::vector out; out.emplace_back(std::move(D)); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 6442b05da1..541b16848e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -305,6 +305,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("Const_buf"), 
py::arg("tokens_per_expert"), py::arg("num_rows"), py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); + // Dropout + m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG", + py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt); + m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG", + py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"), + py::arg("grad_input") = std::nullopt); + // Misc m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version, "Get cublasLt version", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0c75789ed9..e04d424a36 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8Quantizer::create_tensor( @@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - // quantize output and its transpose - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8CurrentScalingQuantizer::create_tensor( @@ -561,20 +540,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // Change the rowwise and columnwise_data to the configured dtype. - // May be a switch between E5M2 and E4M3. 
- auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -916,18 +882,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize this->dtype = quantizer.attr("dtype").cast(); } -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index eda18a185b..f0fe557c07 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -850,7 +850,7 @@ def make_graphed_callables( num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, - fp8_enabled: bool = False, + fp8_enabled: SingleOrTuple[bool] = False, fp8_calibrating: bool = False, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, @@ -896,8 +896,9 @@ def make_graphed_callables( FP8-related parameters ---------------------- - fp8_enabled: bool, default = `True` - whether or not to enable fp8 + fp8_enabled: (tuple of) bool, default = `False` + whether or not to enable fp8. + If tuple, the length must match the number of modules. fp8_calibrating: bool, default = `False` calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is @@ -919,17 +920,25 @@ def make_graphed_callables( """ set_capture_start() - if fp8_enabled and fp8_recipe is None: - fp8_recipe = get_default_fp8_recipe() - elif not fp8_enabled: - fp8_recipe = None - # Handle single module. just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True modules = (modules,) + if not isinstance(fp8_enabled, tuple): + assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools" + fp8_enabled = (fp8_enabled,) * len(modules) + else: + assert len(fp8_enabled) == len( + modules + ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})" + if any(fp8_enabled) and fp8_recipe is None: + fp8_recipe = get_default_fp8_recipe() + elif not any(fp8_enabled): + fp8_recipe = None + module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled)) + # Store FP8 tensors to reset later. 
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) @@ -944,15 +953,15 @@ def wrap_autocast(block): old_call_funcs[block_cls] = block_cls.__call__ # Wrap the original call function of the module class. - def call_func(*args, **kwargs): + def call_func(self, *args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, + enabled=module_uses_fp8.get(id(self), False), calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, fp8_group=fp8_group, _graph=True, ): - outputs = old_call_funcs[block_cls](*args, **kwargs) + outputs = old_call_funcs[block_cls](self, *args, **kwargs) return outputs block_cls.__call__ = call_func diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 5074d32aa2..ac682190c2 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -11,4 +11,4 @@ from .rmsnorm import RMSNorm from .fp8_padding import Fp8Padding from .fp8_unpadding import Fp8Unpadding -from .base import initialize_ub, destroy_ub +from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5d04b29f71..a6275abd19 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -8,6 +8,7 @@ import os import pickle import warnings +from enum import Enum from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager @@ -49,7 +50,7 @@ from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled -__all__ = ["initialize_ub", "destroy_ub"] +__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -63,6 +64,15 @@ layers_atomic_ring_exchange = [] +class UserBufferQuantizationMode(Enum): + """ + UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. + """ + + NONE = "none" + FP8 = "fp8" + + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: @@ -111,8 +121,9 @@ def initialize_ub( shape: list, tp_size: int, use_fp8: bool = False, + quantization_modes: List[UserBufferQuantizationMode] = None, dtype: torch.dtype = torch.bfloat16, - ub_cfgs: Optional[dict] = None, + ub_cfgs: Optional[Union[dict, List[dict]]] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: r""" @@ -128,7 +139,11 @@ def initialize_ub( tp_size : int number of GPUs in the tensor-parallel process group use_fp8 : bool = False - allocate the communication buffer for FP8 GEMM inputs/outputs + allocate the communication buffer for FP8 GEMM inputs/outputs. + DEPRECATED: Please use `quantization_modes` instead. + quantization_modes : List[UserBufferQuantizationMode] = None + if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. + falls back to the legacy `use_fp8` parameter if `None` is provided. 
dtype : torch.dtype = torch.bfloat16 non-FP8 data type of the communication buffer when `use_fp8 = False` ub_cfgs: dict = None @@ -152,6 +167,7 @@ def initialize_ub( for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "fc2_fprop", "fc2_wgrad"]`. + a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` bootstrap_backend : str = None `torch.distributed` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are @@ -168,6 +184,28 @@ def initialize_ub( + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." ) + if not quantization_modes: + warnings.warn( + "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes" + " instead.", + DeprecationWarning, + ) + quantization_modes = [ + UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE + ] + else: + assert isinstance(quantization_modes, list), "quantization_modes must be a list" + assert all( + isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes + ), "quantization_modes must be a list of UserBufferQuantizationMode" + + if isinstance(ub_cfgs, dict) or ub_cfgs is None: + ub_cfgs = [ub_cfgs] * len(quantization_modes) + else: + assert len(ub_cfgs) == len( + quantization_modes + ), "Number of ub_cfgs settings must match number of quantization configurations" + global _ub_communicators assert _ub_communicators is None, "UB communicators are already initialized." _ub_communicators = {} @@ -309,6 +347,7 @@ def get_default_config(name): def add_ub( name: str, + quantization_mode: UserBufferQuantizationMode, method: str, is_reduce_scatter: bool, num_sm: int = 16, @@ -327,7 +366,9 @@ def add_ub( warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) - assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." + assert ( + quantization_mode == UserBufferQuantizationMode.FP8 + ), "Atomic GEMM overlap supported only for FP8 GEMM." if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." 
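
A short, hedged sketch of the new initialization path described in the docstring above (not part of the patch; the shapes and sizes are placeholder values):

```python
import torch
import transformer_engine.pytorch as te

seq_len, batch_size, hidden_size, tp_size = 2048, 2, 4096, 8  # placeholder values

# One set of Userbuffers communicators is created per listed quantization mode,
# replacing the deprecated boolean `use_fp8` switch.
te.initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=tp_size,
    quantization_modes=[
        te.UserBufferQuantizationMode.FP8,
        te.UserBufferQuantizationMode.NONE,
    ],
    dtype=torch.bfloat16,
)
```

When `ub_cfgs` is given as a list, its entries pair up positionally with `quantization_modes`, so each mode can carry its own overlap configuration.
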
@@ -367,7 +408,11 @@ def add_ub( f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" ) - buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype + buffer_dtype = ( + torch.uint8 + if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) + else dtype + ) if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape @@ -401,38 +446,47 @@ def add_ub( comm_priority=comm_priority, rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) - _ub_communicators[name] = ub_obj - - if ub_cfgs is not None: - for name in dgrad_reduce_scatter_overlap: - if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk": - wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in ub_cfgs - layers_reduce_scatter_overlap.remove(wgrad_name) - layers_all_gather_overlap.remove(name) - layers_reduce_scatter_overlap.append(name) - methods["bulk"].remove(name) - new_method = ub_cfgs[name]["method"] - methods[new_method].append(name) - - for name in ( - methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] - ): - ub_cfg = get_default_config(name) - if ub_cfgs is not None and name in ub_cfgs: - fp8_buf = (name in layers_all_gather_overlap) or ( - ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"] - ) - ub_cfg.update(ub_cfgs[name]) - ub_cfg["fp8_buf"] = fp8_buf - add_ub(name, **ub_cfg) + _ub_communicators[(name, quantization_mode)] = ub_obj + + for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): + if user_ub_cfg is not None: + for name in dgrad_reduce_scatter_overlap: + if ( + name in user_ub_cfg + and "method" in user_ub_cfg[name] + and user_ub_cfg[name]["method"] != "bulk" + ): + wgrad_name = name.replace("dgrad", "wgrad") + assert wgrad_name not in user_ub_cfg + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) + layers_reduce_scatter_overlap.append(name) + methods["bulk"].remove(name) + new_method = user_ub_cfg[name]["method"] + methods[new_method].append(name) + + for name in ( + methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] + ): + ub_cfg = get_default_config(name) + if user_ub_cfg is not None and name in user_ub_cfg: + fp8_buf = (name in layers_all_gather_overlap) or ( + user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] + ) + ub_cfg.update(user_ub_cfg[name]) + ub_cfg["fp8_buf"] = fp8_buf + add_ub(name, quantization_mode, **ub_cfg) -def get_ub(name: str): +def get_ub(name: str, use_fp8: bool): """Get userbuffer communicator corresponding to give key.""" + # For now use `use_fp8` boolean input as it matches the current design in the modules + # So favour simplicity until the correct design becomes clear. + # This is mainly an internal API so we don't need to worry about future changes + key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) assert _ub_communicators is not None, "UB manager is not initialized." - assert name in _ub_communicators, f"UB for {name} is not registered." - return _ub_communicators[name] + assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." 
+ return _ub_communicators[key] def destroy_ub(): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 04e3eba7da..cd02f31132 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -173,10 +173,10 @@ def forward( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output ) if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.RS elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG # Configure quantizer for norm output @@ -575,23 +575,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -769,7 +769,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1492,10 +1492,14 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_grad = True with torch.cuda.device( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2e51ac948c..182bf99f86 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -307,7 +307,7 @@ def forward( fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: # Copy into Userbuffers buffer - ub_obj_lnout = get_ub("fc1_fprop") + ub_obj_lnout = get_ub("fc1_fprop", fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_lnout, ln_out, @@ -458,7 +458,7 @@ def forward( ub_obj_fc2out = None reduce_scatter_out = None if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") + ub_obj_fc2out = get_ub("fc2_fprop", fp8) 
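
For completeness, a tiny hedged example of the keyed lookup that the module-level call sites in this patch rely on (illustrative only; it assumes Userbuffers were initialized for both modes as sketched earlier):

```python
# Communicators are now stored under (name, quantization_mode); call sites pass
# the active fp8 flag so the buffer matching the autocast state is selected.
from transformer_engine.pytorch.module.base import get_ub

ub_fp8 = get_ub("fc1_fprop", True)    # FP8-mode communicator
ub_bf16 = get_ub("fc1_fprop", False)  # non-quantized communicator
```
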
dim_size = list(act_out.size()) dim_size[0] //= tp_world_size dim_size[-1] = fc2_weight.size(0) @@ -740,7 +740,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -764,7 +764,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -869,7 +869,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad") + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1036,16 +1036,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -1539,7 +1539,11 @@ def __init__( self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == "gelu" - and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) + and all( + ("fc1_fprop", use_fp8) not in _ub_communicators + or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm() + for use_fp8 in [False, True] + ) ) self.name = name @@ -1757,7 +1761,7 @@ def forward( fp8_output = False if self.ub_overlap_rs: - if get_ub("fc2_fprop").is_fp8_ubuf(): + if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True with torch.cuda.device( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 695cbb4e61..2ce6fb4c1d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -145,10 +145,10 @@ def forward( ub_obj = None ub_type = None if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.RS elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG # ------------------------------------------------------ @@ -520,23 +520,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - 
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -769,7 +769,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1377,10 +1377,14 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_grad = True with torch.cuda.device( diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py index 42f5a1d551..38df5fc54a 100644 --- a/transformer_engine/pytorch/onnx_extensions.py +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -112,7 +112,9 @@ def onnx_quantize_fp8_symbolic( doc="TRT FP8 Quantize Linear used for inference.", inputs=[ defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"), - defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(float)", "Inverse scale factor for quantization" + ), ], outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")], ) @@ -126,11 +128,10 @@ def onnx_quantize_fp8_symbolic( @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[]) -def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: +def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: """Dequantize from Float8Tensor used for inference.""" - scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device) quantizer = Float8Quantizer( - scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 + 1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 ) quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32) return quantizer_tensor.dequantize() @@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor: def onnx_dequantize_fp8_symbolic( - tensor: onnxscript.onnx_types.TensorType, scale: float + tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType ) -> onnxscript.onnx_types.TensorType: """Symbolic dequantize from Float8Tensor used for inference.""" - scale_inv = op.Constant(value_float=1 / scale) return 
TRT_FP8DequantizeLinear(tensor, scale_inv) @@ -157,7 +157,9 @@ def onnx_dequantize_fp8_symbolic( doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.", inputs=[ defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"), - defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(float)", "Inverse scale factor for dequantization" + ), ], outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")], ) @@ -166,6 +168,43 @@ def onnx_dequantize_fp8_symbolic( opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema ) +# ONNX FP8 Current Scaling Quantization + + +@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[]) +def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize to FP8 with current scaling; returns (uint8, scale_inv).""" + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + amax = tensor.abs().max() + eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device) + amax = torch.maximum(amax, eps) + fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device) + scale = fp8_max / amax + q = torch.ops.tex.fp8_quantize(tensor, scale) + scale_inv = 1 / scale + return q, scale_inv + + +@onnx_cs_quantize_fp8_op.register_fake +def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones( + 1, dtype=torch.float32, device=tensor.device + ) + + +def onnx_quantize_fp8_cs_symbolic( + tensor: onnxscript.onnx_types.TensorType, +): + """Symbolic quantize with current scaling; computes scale_inv from tensor.""" + # scale_inv = 1 / max(abs(tensor)) + amax = op.ReduceMax(op.Abs(tensor), keepdims=0) + eps = op.Constant(value_float=1.0e-12) + amax = op.Max(amax, eps) + scale_inv = op.Div(amax, op.Constant(value_float=448.0)) + q = TRT_FP8QuantizeLinear(tensor, scale_inv) + return q, scale_inv + # ONNX MXFP8 Quantization @@ -356,6 +395,7 @@ def onnx_attention_mask_func( torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic, torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic, torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic, + torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic, torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic, torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic, torch.ops.tex.layernorm.default: onnx_layernorm_symbolic, diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 958e9b06ce..f0f55322c4 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -8,12 +8,11 @@ from typing import Optional import torch - -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) +import transformer_engine_torch as tex from ...tensor import Quantizer +from ...tensor._internal.float8_tensor_base import Float8TensorBase +from .._common import maybe_autocast_dtype, maybe_dequantize +from ..op import BasicOperation, OperationContext class Dropout(BasicOperation): @@ -27,7 +26,7 @@ class Dropout(BasicOperation): def __init__(self, p: float) -> None: super().__init__() - self.dropout_probability = p + self.dropout_probability: float = p def op_forward( self, @@ -37,21 +36,44 @@ def op_forward( next_op_input_quantizer: 
Optional[Quantizer], ) -> torch.Tensor: - # Compute dropout if training - out = input_ - is_training = self.training - mask = None - if is_training: + # Output dtype + dtype = maybe_autocast_dtype(default_dtype=input_.dtype) + + # Choose implementation + impl = None + if not self.training: + impl = "evaluation" + elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16): + impl = "fused" + else: + impl = "unfused" + + # Perform dropout + out: torch.Tensor + mask: Optional[torch.Tensor] = None + if impl == "evaluation": + out = input_ + elif impl == "fused": + x = input_ + if not isinstance(x, Float8TensorBase): + x = maybe_dequantize(x, dtype=dtype) + out, mask = tex.dropout_fwd(x, self.dropout_probability) + elif impl == "unfused": + x = maybe_dequantize(input_, dtype=dtype) keep_prob = 1 - self.dropout_probability - mask = torch.empty_like(input_) + mask = torch.empty_like(x) mask.bernoulli_(keep_prob) mask *= 1 / keep_prob - out = out * mask + out = x * mask + else: + raise ValueError(f"Unsupported forward implementation {impl}") # Save context for backward if ctx.requires_grad: ctx.save_for_backward(mask) - ctx.is_training = is_training + ctx.impl = impl + ctx.dropout_probability = self.dropout_probability + ctx.dtype = dtype return out @@ -60,8 +82,21 @@ def op_backward( ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass (mask,) = ctx.saved_tensors - grad_input = grad_output - if ctx.is_training: - grad_input = grad_input * mask + + # Perform dropout backward pass + grad_input: torch.Tensor + if ctx.impl == "evaluation": + grad_input = grad_output + elif ctx.impl == "fused": + dy = maybe_dequantize(grad_output, dtype=ctx.dtype) + grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability) + elif ctx.impl == "unfused": + dy = maybe_dequantize(grad_output, dtype=ctx.dtype) + grad_input = dy * mask + else: + raise ValueError(f"Unsupported backward implementation {ctx.impl}") + return grad_input, () diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index c595325212..1ecdba6253 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -241,16 +241,16 @@ def _functional_backward( with_dgrad_all_gather_x = False with_wgrad_reduce_scatter_dx = False if tensor_parallel_mode == "row": - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.AG with_dgrad_all_gather_dy = True elif tensor_parallel_mode == "column": if input_requires_grad and weight_requires_grad: with_bulk_overlap = True - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.AG with_dgrad_all_gather_x = True - ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad") + ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute) ub_type_wgrad = CommOverlapType.RS with_wgrad_reduce_scatter_dx = True if ub_comm_wgrad.is_fp8_ubuf(): @@ -258,7 +258,7 @@ def _functional_backward( "Userbuffers reduce-scatter is not supported with FP8 buffers" ) else: - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.RS with_dgrad_reduce_scatter_dx = True if 
ub_comm_dgrad.is_fp8_ubuf(): @@ -409,7 +409,7 @@ def _functional_backward( # Get the communication stream from the dgrad GEMM to use for the AG dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream() - ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad") + ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute) grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 61853f9f41..574642794f 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -189,7 +189,7 @@ def _functional_forward( output_quantizer = None # Get Userbuffers communicator - ub_comm = get_ub(ub_comm_name + "_fprop") + ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute) with_ub_all_gather = tensor_parallel_mode == "column" with_ub_reduce_scatter = tensor_parallel_mode == "row" ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index ae1b5780bb..46543acf28 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -10,14 +10,30 @@ import os import shutil from pathlib import Path - +import platform +import urllib import setuptools +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from packaging.version import parse try: + import torch from torch.utils.cpp_extension import BuildExtension except ImportError as e: raise RuntimeError("This package needs Torch to build.") from e +FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE" +FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +PACKAGE_NAME = "transformer_engine_torch" +BASE_WHEEL_URL = ( + "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}" +) +# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as +# torch._C._GLIBCXX_USE_CXX11_ABI +# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 +if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True current_file_path = Path(__file__).parent.resolve() build_tools_dir = current_file_path.parent.parent / "build_tools" @@ -31,13 +47,94 @@ from build_tools.build_ext import get_build_ext from build_tools.utils import copy_common_headers from build_tools.te_version import te_version -from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements +from build_tools.pytorch import ( + setup_pytorch_extension, + install_requirements, + test_requirements, +) os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) +def get_platform(): + """ + Returns the platform name as used in wheel filenames. 
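+    For example, ``linux_x86_64`` on x86-64 Linux or ``win_amd64`` on Windows.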
+ """ + if sys.platform.startswith("linux"): + return f"linux_{platform.uname().machine}" + if sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + if sys.platform == "win32": + return "win_amd64" + + raise ValueError(f"Unsupported platform: {sys.platform}") + + +def get_wheel_url(): + """Construct the wheel URL for the current platform.""" + torch_version_raw = parse(torch.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + nvte_version = te_version() + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}" + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + + wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename) + + return wheel_url, wheel_filename + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all grouped gemm installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + + def run(self): + if FORCE_BUILD: + super().run() + + wheel_url, wheel_filename = get_wheel_url() + print("Guessing wheel URL: ", wheel_url) + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): + print("Precompiled wheel not found. 
Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + if __name__ == "__main__": # Extensions common_headers_dir = "common_headers" @@ -50,11 +147,11 @@ # Configure package setuptools.setup( - name="transformer_engine_torch", + name=PACKAGE_NAME, version=te_version(), description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, install_requires=install_requirements(), tests_require=test_requirements(), ) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index acc03ba78f..1524584aa7 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -177,7 +177,7 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: """Function using primitives with ONNX defined translations.""" - out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item()) + out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv) out = out.to(tensor.dtype) return out @@ -350,15 +350,25 @@ def create_tensor_from_data( def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" - raise NotImplementedError( - "Float8CurrentScalingQuantizer does not support ONNX quantization yet." + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor) + return Float8Tensor( + shape=data.shape, + dtype=torch.float32, + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + requires_grad=False, + data_transpose=None, + quantizer=self, ) def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: """Function using primitives with ONNX defined translations.""" - raise NotImplementedError( - "Float8CurrentScalingQuantizer does not support ONNX dequantization yet." 
- ) + out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv) + out = out.to(tensor.dtype) + return out def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" From d4c06c58a83e29dd81095dd577d09af2b0a16f86 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 5 Sep 2025 23:44:40 +0000 Subject: [PATCH 02/53] initial draft of changes to get GPT oss based swiglu integrated, gated kernels needs to be fixed Signed-off-by: Varun Thumbe --- .../common/activation/activation_template.h | 10 +- transformer_engine/common/activation/gelu.cu | 8 +- transformer_engine/common/activation/relu.cu | 8 +- .../common/activation/swiglu.cu | 20 ++- .../include/transformer_engine/activation.h | 15 +++ .../common/util/cast_gated_kernels.cuh | 119 ++++++++++-------- transformer_engine/common/util/math.h | 23 ++++ transformer_engine/pytorch/csrc/extensions.h | 3 + .../pytorch/csrc/extensions/activation.cpp | 10 ++ 9 files changed, 147 insertions(+), 69 deletions(-) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 67f173a4ab..3f701b1560 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, } template -void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { +void gated_act_fn(const NVTETensor input, NVTETensor output, Param p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = false; constexpr NVTETensor grad = nullptr; - - quantize_gated_helper(grad, input, output, stream); + quantize_gated_helper(grad, input, output, p, stream); } template -void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = true; - - quantize_gated_helper(grad, input, output, stream); + quantize_gated_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 0cf43007a7..cea17463bd 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -23,14 +23,14 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + gated_act_fn>(input, output, {}, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>(grad, input, output, stream); + dgated_act_fn, dgelu>(grad, input, output, {}, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,12 +49,12 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + gated_act_fn>(input, output, {}, stream); } void nvte_dqgeglu(const NVTETensor grad, 
const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>(grad, input, output, stream); + dgated_act_fn, dqgelu>(grad, input, output, {}, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index a794b7315f..e7748a8cd6 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -23,14 +23,14 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + gated_act_fn>(input, output, {}, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>(grad, input, output, stream); + dgated_act_fn, drelu>(grad, input, output, {}, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,12 +49,12 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + gated_act_fn>(input, output, {}, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>(grad, input, output, stream); + dgated_act_fn, dsrelu>(grad, input, output, {}, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 8194964745..d602d9076f 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -23,12 +23,28 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + gated_act_fn>(input, output, {}, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>(grad, input, output, stream); + dgated_act_fn, dsilu>(grad, input, output, {}, stream); } + +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float alpha, + float min_limit, float max_limit, cudaStream_t stream){ + NVTE_API_CALL(nvte_gptoss_swiglu); + using namespace transformer_engine; + GptOssParam param = {alpha, min_limit, max_limit}; + gated_act_fn>(input, output, param, stream); +} + +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float alpha, + float min_limit, float max_limit, cudaStream_t stream){ + NVTE_API_CALL(nvte_gptoss_dswiglu); + using namespace transformer_engine; + GptOssParam param = {alpha, min_limit, max_limit}; + dgated_act_fn, oss_dsilu>(grad, input, output, param, stream); +} \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 
49029ed588..0a2d717f79 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -182,6 +182,14 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) * It computes Act(input[N, :H]) x input[N, H:] * \param[in] stream CUDA stream used for the operation. */ + +/* +TODO: Add documentation once the API finalizes. +*/ +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float alpha, + float min_limit, float max_limit, cudaStream_t stream); + + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the gated Quick GeLU activation of the input. @@ -239,6 +247,13 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp * \param[in,out] output Outgoing gradient of shape [N, H * 2]. * \param[in] stream CUDA stream used for the operation. */ + +/* +TODO: Add documentation once the API finalizes. +*/ +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float alpha, + float min_limit, float max_limit, cudaStream_t stream); + void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 50ff82d85f..ef4d0c5cb4 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -55,7 +55,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_act, const __grid_constant__ CUtensorMap tensor_map_output_gate, float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { + const float *const scale_ptr, const size_t rows, const size_t cols, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; @@ -176,18 +176,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + //TODO: Fix this code for GPT OSS + float act_x=0.0f; + float dact_x=0.0f; + if constexpr(std::is_same::value){ + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } + else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + float after_dact = dact_x * grad_elt * gate_elt; float after_dgate = act_x * grad_elt; @@ -197,7 +201,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) amax = fmaxf(amax, fabsf(after_dact)); amax = fmaxf(amax, fabsf(after_dgate)); } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; + const float after_act = ActOP(act_elt, p) * gate_elt; out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); amax = fmaxf(amax, fabsf(after_act)); } @@ -300,7 +304,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t 
scale_stride_colwise) { + const size_t scale_stride_colwise, + const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -480,21 +485,25 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + //TODO: Fix this code for GPT OSS + float act_x=0.0f; + float dact_x=0.0f; + if constexpr(std::is_same::value){ + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } + else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; after_gate_elt = act_x * grad_elt; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { @@ -723,17 +732,21 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + // TODO: Fix this code for GPT OSS + float act_x=0.0f; + float dact_x=0.0f; + if constexpr(std::is_same::value){ + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } + else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } } + after_act_elt = dact_x * grad_elt * gate_elt; after_gate_elt = act_x * grad_elt; after_act_rowwise[j] = after_act_elt; @@ -883,7 +896,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -958,14 +971,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); + cols, p); NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) } template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1096,7 +1109,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1113,7 +1126,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out 
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: @@ -1130,7 +1143,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; }); // NOLINT(*) @@ -1138,7 +1151,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } template -void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), @@ -1165,7 +1178,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), {}, stream); + output->flat_last_dim(), p, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1174,7 +1187,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); CheckOutputTensor(*output, "dgated_act_output"); @@ -1203,7 +1216,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), {}, stream); + grad.flat_last_dim(), p, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1212,7 +1225,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); @@ -1252,17 +1265,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (is_delayed_tensor_scaling(output->scaling_mode)) { if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, stream); + cast_fp8_gated(grad, gated_input, output, p, stream); } else { if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); + cast_dgated(grad, gated_input, output, p, stream); } else { - cast_gated(gated_input, output, stream); + cast_gated(gated_input, output, p, stream); } } } else if (is_mxfp_scaling(output->scaling_mode)) { if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); + 
cast_mxfp8_gated(grad, gated_input, output, p, stream); } else { NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", "by 32, got input of shape ", gated_input.data.shape); @@ -1277,7 +1290,7 @@ namespace detail { template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, +void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, ParamOP p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; @@ -1287,13 +1300,13 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, if (is_supported_by_CC_100()) { quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); + output_tensor, p, stream); } else { if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, stream); } else { - cast_gated(gated_input_tensor, output_tensor, stream); + cast_gated(gated_input_tensor, output_tensor, p, stream); } } else { // MX scaling diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2d425d6753..58ee519ad4 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -11,6 +11,12 @@ namespace transformer_engine { struct Empty {}; +struct GptOssParam{ + float alpha; + float min_limit; + float max_limit; +}; + template __device__ inline OType gelu(const IType val, const Empty&) { const float cval = val; @@ -57,12 +63,29 @@ __device__ inline OType silu(const IType val, const Empty& e) { return cval * sigmoid(cval, e); } +template +__device__ inline OType oss_silu(const IType val, const GptOssParam& p) { + const Empty e = {}; + const float cval = max(min(val, p.min_limit), p.max_limit); // Clamping + return cval * sigmoid(p.alpha * cval, e); +} + template __device__ inline OType dsilu(const IType val, const Empty& e) { const float cval = val; return cval * dsigmoid(cval, e) + sigmoid(cval, e); } +template +__device__ inline OType oss_dsilu(const IType val, const GptOssParam& p) { + const Empty e = {}; + const bool dclamp_val = (val <= p.max_limit) && (val >= p.min_limit); + const float clamp_val = max(min(val, p.min_limit), p.max_limit); + const float dsilu_val = (p.alpha * clamp_val) * dsigmoid(p.alpha * clamp_val, e) + + sigmoid(p.alpha * clamp_val, e); + return dclamp_val ? 
dsilu_val: 0.0f; +} + template __device__ inline OType relu(IType value, const Empty&) { return fmaxf(value, 0.f); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a6b65562eb..368547a73b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -197,6 +197,9 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); + +py::object gpt_oss_dswiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7851cc5ffc..daaca9ef93 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "pybind.h" + namespace transformer_engine::pytorch { template @@ -183,4 +184,13 @@ py::object swiglu(const at::Tensor& input, py::handle quantizer) { py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { return dactivation_helper(grad, input, quantizer); } + +py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ + +} + +py::object gpt_oss_dswiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ + +} + } // namespace transformer_engine::pytorch From 1f596af8a9808b393adf45a1a9958c587f55bf25 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sat, 6 Sep 2025 00:51:29 +0000 Subject: [PATCH 03/53] redundant implementation for the pytorch to te hook up, refactoring to be done later Signed-off-by: Varun Thumbe --- .../common/util/cast_gated_kernels.cuh | 36 ++++++--- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/activation.cpp | 81 ++++++++++++++++++- 3 files changed, 104 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index ef4d0c5cb4..7fe38f2b7b 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -176,10 +176,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); const float x = act_elt; - //TODO: Fix this code for GPT OSS - float act_x=0.0f; - float dact_x=0.0f; - if constexpr(std::is_same::value){ + float act_x; + float dact_x; + if constexpr(std::is_same::value){ + //TODO: Fix this code for GPT OSS + act_x = 0.0f; + dact_x = 0.0f; + } + else{ if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); act_x = x * s; @@ -485,10 +489,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; - //TODO: Fix this code for GPT OSS - float act_x=0.0f; - float dact_x=0.0f; - if constexpr(std::is_same::value){ + float act_x; + float 
dact_x; + if constexpr(std::is_same::value){ + //TODO: Fix this code for GPT OSS + act_x=0.0f; + dact_x=0.0f; + } + else{ if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); act_x = x * s; @@ -732,10 +740,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; - // TODO: Fix this code for GPT OSS - float act_x=0.0f; - float dact_x=0.0f; - if constexpr(std::is_same::value){ + float act_x; + float dact_x; + if constexpr(std::is_same::value){ + // TODO: Fix this code for GPT OSS + act_x = 0.0f; + dact_x = 0.0f; + } + else{ if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); act_x = x * s; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 368547a73b..63fe70235f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -199,7 +199,7 @@ py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle q py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); -py::object gpt_oss_dswiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); +py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index daaca9ef93..f7fd994832 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -186,11 +186,88 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q } py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ + init_extension(); + // Input tensor + auto input_tensor = input.contiguous(); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct output tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape = input_cpp.shape(); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + + // Compute activation + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation directly + NVTE_SCOPED_GIL_RELEASE( + { nvte_gptoss_swiglu(input_cpp.data(), out_cpp.data(), alpha, min_limit, max_limit, at::cuda::getCurrentCUDAStream()); }); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // Compute activation in high-precision fused together with amax, then quantize. 
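+    // Sketch of the intended flow: the activation kernel writes a high-precision
+    // result while recording its amax, and quantize_with_amax() then derives the
+    // FP8 scale from that amax and performs the cast, so no separate amax pass
+    // over the output is needed.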
+ auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); + } else { + // Compute activation in high-precision, then quantize + + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp->quantize(temp_cpp, out_cpp); + } + return out_py; } -py::object gpt_oss_dswiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ - +py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ + init_extension(); + // Grad output and input tensors + auto grad_output_tensor = grad_output.contiguous(); + auto input_tensor = input.contiguous(); + const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape_te = input_cpp.shape(); + const std::vector input_shape(input_shape_te.data, + input_shape_te.data + input_shape_te.ndim); + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + + // Compute activation backward + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation backward directly + NVTE_SCOPED_GIL_RELEASE({ + nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), alpha, min_limit, max_limit, + at::cuda::getCurrentCUDAStream()); + }); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // Compute activation backward in high-precision fused together with amax, then quantize. 
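+    // Same pattern as the forward path: the gradient is produced in high precision
+    // along with its amax, then quantize_with_amax() casts it to the requested FP8 format.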
+ auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); + } else { + // Compute activation backward in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + } + + return grad_input_py; + } } // namespace transformer_engine::pytorch From 42f85c305f8938b5fdea872db036fd84e4903d65 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 03:46:28 +0000 Subject: [PATCH 04/53] all gated kernels modified, pytest working for oss swiglu Signed-off-by: Varun Thumbe --- tests/pytorch/test_fusible_ops.py | 70 +++++++++++++++++++ .../common/activation/swiglu.cu | 12 ++-- .../include/transformer_engine/activation.h | 6 +- .../common/util/cast_gated_kernels.cuh | 69 ++++++++++++++---- transformer_engine/common/util/math.h | 30 +++++--- .../common/util/vectorized_pointwise.h | 23 ++++-- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/activation.cpp | 17 ++--- .../pytorch/csrc/extensions/pybind.cpp | 4 ++ .../pytorch/ops/basic/__init__.py | 2 +- .../pytorch/ops/basic/activation.py | 38 ++++++++++ 11 files changed, 226 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index bb07e87d98..646f3ad23c 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1707,6 +1707,76 @@ def test_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) + def test_gpt_oss_swiglu( + self, + *, + out_shape: Iterable[int] = (32, 32), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, + ): + print(_quantization_list) + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x_glu, x_linear = x_ref.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=7.0) + x_linear = x_linear.clamp(min=-7.0, max=7.0) + out_glu = x_glu * torch.sigmoid(1.702 * x_glu) + y_ref = out_glu * (x_linear + 1) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + recipe = make_recipe(quantization) + + forward = 
te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantize_backward), + te_ops.GptOssSwiglu(limit=7.0), + te_ops.Quantize(forward=quantize_forward, backward=False)) + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = forward(x_test) + + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @pytest.mark.parametrize("dtype", _dtypes) diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index d602d9076f..0081219027 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -33,18 +33,18 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp dgated_act_fn, dsilu>(grad, input, output, {}, stream); } -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float alpha, - float min_limit, float max_limit, cudaStream_t stream){ +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, + float limit, cudaStream_t stream){ NVTE_API_CALL(nvte_gptoss_swiglu); using namespace transformer_engine; - GptOssParam param = {alpha, min_limit, max_limit}; + GptOssParam param = {limit}; gated_act_fn>(input, output, param, stream); } -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float alpha, - float min_limit, float max_limit, cudaStream_t stream){ +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, cudaStream_t stream){ NVTE_API_CALL(nvte_gptoss_dswiglu); using namespace transformer_engine; - GptOssParam param = {alpha, min_limit, max_limit}; + GptOssParam param = {limit}; dgated_act_fn, oss_dsilu>(grad, input, output, param, stream); } \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 0a2d717f79..c735be8926 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -186,8 +186,7 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) /* TODO: Add documentation once the API finalizes. */ -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float alpha, - float min_limit, float max_limit, cudaStream_t stream); +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, cudaStream_t stream); void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); @@ -251,8 +250,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp /* TODO: Add documentation once the API finalizes. 
*/ -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float alpha, - float min_limit, float max_limit, cudaStream_t stream); +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, cudaStream_t stream); void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 7fe38f2b7b..d7994267b7 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -161,7 +161,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; OType *out_act_sh_curr = out_act_sh + buff * buff_elems; OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; - #pragma unroll for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; @@ -171,6 +170,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + float dgate_elt = 1.0f; // gating is ideally an identity function + if constexpr(std::is_same::value){ + // In case of GPT OSS, clamp the activation and gate values + const float limit = p.limit; + dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp + gate_elt = min(max(-limit, gate_elt), limit) + 1; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); @@ -179,9 +185,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_x; float dact_x; if constexpr(std::is_same::value){ - //TODO: Fix this code for GPT OSS - act_x = 0.0f; - dact_x = 0.0f; + const float limit = p.limit; + const float x = min(act_elt, limit); + const float s = sigmoidf(1.702 * x); + act_x = x * s; + if(x <= limit){ + dact_x = s + s * (1 - s) * 1.702 * x; + } + else{ + dact_x = 0.0f; + } } else{ if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { @@ -197,7 +210,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; + float after_dgate = act_x * grad_elt * dgate_elt; out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); @@ -485,16 +498,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); float after_act_elt; float after_gate_elt; - + float dgate_elt = 1.0f; // gating is ideally an identity function + if constexpr(std::is_same::value){ + // In case of GPT OSS, clamp the activation and gate values + const float limit = p.limit; + dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 
1.0f : 0.0f; // Derivative of clamp + gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; float dact_x; if constexpr(std::is_same::value){ - //TODO: Fix this code for GPT OSS - act_x=0.0f; - dact_x=0.0f; + const float limit = p.limit; + const float x = min(act_elt, limit); + const float s = sigmoidf(1.702 * x); + act_x = x * s; + if(x <= limit){ + dact_x = s + s * (1 - s) * 1.702 * x; + } + else{ + dact_x = 0.0f; + } } else{ if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { @@ -509,7 +535,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = act_x * grad_elt * dgate_elt; } else { after_act_elt = ActOP(act_elt, p) * gate_elt; } @@ -736,16 +762,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate.data.elt[e]); float after_act_elt; float after_gate_elt; - + float dgate_elt = 1.0f; + if constexpr(std::is_same::value){ + // In case of GPT OSS, clamp the activation and gate values + const float limit = p.limit; + dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp + gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; float dact_x; if constexpr(std::is_same::value){ - // TODO: Fix this code for GPT OSS - act_x = 0.0f; - dact_x = 0.0f; + const float limit = p.limit; + const float x = min(act_elt, limit); + const float s = sigmoidf(1.702 * x); + act_x = x * s; + if(x <= limit){ + dact_x = s + s * (1 - s) * 1.702 * x; + } + else{ + dact_x = 0.0f; + } } else{ if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { @@ -760,7 +799,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = act_x * grad_elt * dgate_elt; after_act_rowwise[j] = after_act_elt; after_gate_rowwise[j] = after_gate_elt; } else { diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 58ee519ad4..22a480d13b 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -12,9 +12,7 @@ namespace transformer_engine { struct Empty {}; struct GptOssParam{ - float alpha; - float min_limit; - float max_limit; + float limit; }; template @@ -63,11 +61,18 @@ __device__ inline OType silu(const IType val, const Empty& e) { return cval * sigmoid(cval, e); } +template +__device__ inline OType clamp(const IType val, const float min_limit, const float max_limit) { + const float cval = val; + return max(min(cval, max_limit), min_limit); +} + template __device__ inline OType oss_silu(const IType val, const GptOssParam& p) { const Empty e = {}; - const float cval = max(min(val, p.min_limit), p.max_limit); // Clamping - return cval * sigmoid(p.alpha * cval, e); + const float cval = clamp(val, + -std::numeric_limits::infinity(), p.limit); // Clamping + return qgelu(cval, e); } template @@ -76,13 +81,20 @@ __device__ inline OType dsilu(const IType val, const Empty& e) { return cval * dsigmoid(cval, e) + sigmoid(cval, e); } +template +__device__ inline OType dclamp(const IType val, const float min_limit, const float max_limit) { + const float cval = val; + return cval <= max_limit && cval >= min_limit; +} + template __device__ inline 
OType oss_dsilu(const IType val, const GptOssParam& p) { const Empty e = {}; - const bool dclamp_val = (val <= p.max_limit) && (val >= p.min_limit); - const float clamp_val = max(min(val, p.min_limit), p.max_limit); - const float dsilu_val = (p.alpha * clamp_val) * dsigmoid(p.alpha * clamp_val, e) - + sigmoid(p.alpha * clamp_val, e); + const bool dclamp_val = dclamp(val, + -std::numeric_limits::infinity(), p.limit); + const float clamp_val = clamp(val, + -std::numeric_limits::infinity(), p.limit); + const float dsilu_val = dqgelu(clamp_val, e); return dclamp_val ? dsilu_val: 0.0f; } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 0d667a0ece..b3fc971e26 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -11,7 +11,7 @@ #include "../common.h" #include "../utils.cuh" - +#include "math.h" namespace transformer_engine { /* \brief Helper class that enables storing multiple values of type DType @@ -431,7 +431,14 @@ __launch_bounds__(unary_kernel_threads) __global__ #pragma unroll for (int i = 0; i < nvec; ++i) { const ComputeType val = static_cast(loader0.separate()[i]); - const ComputeType val2 = static_cast(loader1.separate()[i]); + ComputeType val2 = static_cast(loader1.separate()[i]); + + if constexpr(std::is_same::value){ + // Clamp the gated value and add 1 at the end + // https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + ComputeType limit = p.limit; + val2 = std::min(std::max(-limit, val2), limit) + 1; + } ComputeType temp = static_cast(Activation(val, p) * val2); if (requires_amax) { __builtin_assume(max >= 0); @@ -532,10 +539,18 @@ __launch_bounds__(unary_kernel_threads) __global__ for (int i = 0; i < nvec; ++i) { const ComputeType grad_val = static_cast(grad_loader.separate()[i]); const ComputeType gelu_in = static_cast(input_loader0.separate()[i]); - const ComputeType gate_in = static_cast(input_loader1.separate()[i]); + ComputeType gate_in = static_cast(input_loader1.separate()[i]); + ComputeType dgate_in = 1.0f; + + if constexpr(std::is_same::value){ + // In case of GPT OSS, clamp the activation and gate values + const ComputeType limit = p.limit; + dgate_in = gate_in <= limit && gate_in >= -limit ? 
1.0f : 0.0f; // Derivative of clamp + gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; + } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; - ComputeType after_dgate = grad_val * Activation(gelu_in, p); + ComputeType after_dgate = grad_val * Activation(gelu_in, p) * dgate_in; if (requires_amax) { __builtin_assume(max >= 0); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 63fe70235f..657e638b00 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -197,9 +197,9 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); +py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit); -py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit); +py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, float limit); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index f7fd994832..f4d87e3f25 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -185,7 +185,7 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q return dactivation_helper(grad, input, quantizer); } -py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ +py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit){ init_extension(); // Input tensor auto input_tensor = input.contiguous(); @@ -195,6 +195,7 @@ py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float a auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape = input_cpp.shape(); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + output_shape.back() /= 2; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); @@ -203,27 +204,27 @@ py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float a detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE( - { nvte_gptoss_swiglu(input_cpp.data(), out_cpp.data(), alpha, min_limit, max_limit, at::cuda::getCurrentCUDAStream()); }); + { nvte_gptoss_swiglu(input_cpp.data(), out_cpp.data(),limit, at::cuda::getCurrentCUDAStream()); }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation in high-precision fused together with amax, then quantize. 
auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( - { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, at::cuda::getCurrentCUDAStream()); }); + { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { // Compute activation in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( - { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, at::cuda::getCurrentCUDAStream()); }); + { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); quantizer_cpp->quantize(temp_cpp, out_cpp); } return out_py; } -py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &input, py::handle quantizer, float alpha, float min_limit, float max_limit){ +py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &input, py::handle quantizer, float limit){ init_extension(); // Grad output and input tensors auto grad_output_tensor = grad_output.contiguous(); @@ -244,7 +245,7 @@ py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &inpu detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), alpha, min_limit, max_limit, + nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { @@ -252,7 +253,7 @@ py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &inpu auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, + nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); @@ -260,7 +261,7 @@ py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &inpu // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), alpha, min_limit, max_limit, + nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 541b16848e..d4245c823f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -136,6 +136,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); + m.def("gpt_oss_swiglu", transformer_engine::pytorch::gpt_oss_swiglu, "SwiGLU 
activation used in GPT OSS", py::arg("input"), + py::arg("quantizer"), py::arg("limit") = 7.0f); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -159,6 +161,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("gpt_oss_dswiglu", transformer_engine::pytorch::gpt_oss_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer"), py::arg("limit") = 7.0f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 2c903675fb..505a9ea785 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,7 +4,7 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU +from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU, GptOssSwiglu from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 5ef421bc1d..ab72eaf9c1 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -27,6 +27,7 @@ "SReGLU", "SiLU", "SwiGLU", + "GptOssSwiglu" ] @@ -389,3 +390,40 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dswiglu(*args, **kwargs) + + +class GptOssSwiglu(_ActivationOperation): + r"""GPT-OSS SwiGLU with clamped SiLU + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GPT-OSS-SwiGLU}(a, b) = \text{clamp}(a, -\infty, \text{limit}) \cdot \sigma(1.702 \cdot \text{clamp}(a, -\infty, \text{limit})) \cdot (\text{clamp}(b, -\text{limit}, \text{limit}) + 1) + + where + + .. math:: + + a = x[..., ::2], \quad b = x[..., 1::2] + + and :math:`\sigma(x)` is the sigmoid function, and :math:`\text{limit}` is a hyperparameter. + + Implementation based on `GPT-OSS`__. + Parameters + ---------- + limit: float + The clamp limit. 
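For reference, here is a minimal PyTorch sketch of the clamped-SwiGLU math described above and implemented by the fused kernels in this patch, including the hand-derived backward the kernels use (clamp derivative as a 0/1 mask, and d/dx[x*sigmoid(1.702x)] = s + 1.702*x*s*(1-s) below the limit). The a/b split follows the plain-PyTorch reference in tests/pytorch/test_fusible_ops.py (contiguous halves via chunk); the helper names below are illustrative only and are not part of the patch.

import torch

def ref_gpt_oss_swiglu(x: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
    # Forward: clamped SiLU (alpha = 1.702) on the activation half, clamped (+1) gate on the other half.
    a, b = x.chunk(2, dim=-1)
    a = a.clamp(min=None, max=limit)        # clamp(a, -inf, limit)
    b = b.clamp(min=-limit, max=limit)      # clamp(b, -limit, limit)
    return a * torch.sigmoid(1.702 * a) * (b + 1)

def ref_gpt_oss_dswiglu(grad: torch.Tensor, x: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
    # Backward written the way the fused kernels compute it.
    a, b = x.chunk(2, dim=-1)
    a_c = a.clamp(min=None, max=limit)
    b_c = b.clamp(min=-limit, max=limit)
    s = torch.sigmoid(1.702 * a_c)
    act = a_c * s                                            # SiLU(clamp(a))
    dact = torch.where(a <= limit, s + 1.702 * a_c * s * (1 - s), torch.zeros_like(a))
    dgate = ((b >= -limit) & (b <= limit)).to(x.dtype)       # derivative of clamp
    da = grad * (b_c + 1) * dact
    db = grad * act * dgate
    return torch.cat([da, db], dim=-1)

# Cross-check against autograd (float64 keeps the comparison tight).
x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 8, dtype=torch.float64)
ref_gpt_oss_swiglu(x).backward(g)
assert torch.allclose(x.grad, ref_gpt_oss_dswiglu(g, x.detach()))

At the operation level the activation is used as te_ops.GptOssSwiglu(limit=7.0) inside a te_ops.Sequential, as in the test added by this patch.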
+ + """ + + def __init__(self, *, limit: float, cache_quantized_input: bool = False): + super().__init__(cache_quantized_input=cache_quantized_input) + self.limit = limit + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.gpt_oss_swiglu(*args, limit=self.limit, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.gpt_oss_dswiglu(*args, limit=self.limit, **kwargs) \ No newline at end of file From c9d33117ff1df3c94ed9170fa3e782f973a97816 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 03:49:41 +0000 Subject: [PATCH 05/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe --- tests/pytorch/test_fusible_ops.py | 6 +- transformer_engine/common/activation/gelu.cu | 3 +- transformer_engine/common/activation/relu.cu | 3 +- .../common/activation/swiglu.cu | 11 ++- .../include/transformer_engine/activation.h | 7 +- .../common/util/cast_gated_kernels.cuh | 94 +++++++++---------- transformer_engine/common/util/math.h | 16 ++-- .../common/util/vectorized_pointwise.h | 10 +- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/activation.cpp | 30 +++--- .../pytorch/csrc/extensions/pybind.cpp | 10 +- .../pytorch/ops/basic/__init__.py | 14 ++- .../pytorch/ops/basic/activation.py | 4 +- 13 files changed, 113 insertions(+), 98 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 646f3ad23c..5674f842d9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1759,10 +1759,11 @@ def test_gpt_oss_swiglu( forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), te_ops.GptOssSwiglu(limit=7.0), - te_ops.Quantize(forward=quantize_forward, backward=False)) + te_ops.Quantize(forward=quantize_forward, backward=False), + ) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) - + y_test.backward(dy_test) # Expected numerical error @@ -1776,7 +1777,6 @@ def test_gpt_oss_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @pytest.mark.parametrize("dtype", _dtypes) diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index cea17463bd..27ffd3f4e5 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -56,5 +56,6 @@ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>(grad, input, output, {}, stream); + dgated_act_fn, dqgelu>(grad, input, output, {}, + stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index e7748a8cd6..8e598d5ea5 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -56,5 +56,6 @@ void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>(grad, input, output, {}, stream); + dgated_act_fn, dsrelu>(grad, input, output, {}, 
+ stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 0081219027..30042392a6 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -33,8 +33,8 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp dgated_act_fn, dsilu>(grad, input, output, {}, stream); } -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, - float limit, cudaStream_t stream){ +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, + cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_swiglu); using namespace transformer_engine; GptOssParam param = {limit}; @@ -42,9 +42,10 @@ void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, } void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, cudaStream_t stream){ + float limit, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_dswiglu); using namespace transformer_engine; GptOssParam param = {limit}; - dgated_act_fn, oss_dsilu>(grad, input, output, param, stream); -} \ No newline at end of file + dgated_act_fn, oss_dsilu>(grad, input, output, + param, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index c735be8926..10aeba8ede 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -186,8 +186,8 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) /* TODO: Add documentation once the API finalizes. */ -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, cudaStream_t stream); - +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, + cudaStream_t stream); void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); @@ -250,7 +250,8 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp /* TODO: Add documentation once the API finalizes. 
*/ -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, cudaStream_t stream); +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, cudaStream_t stream); void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index d7994267b7..fe1a522a96 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -55,7 +55,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_act, const __grid_constant__ CUtensorMap tensor_map_output_gate, float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols, const ParamOP p) { + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; @@ -170,12 +171,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - float dgate_elt = 1.0f; // gating is ideally an identity function - if constexpr(std::is_same::value){ + float dgate_elt = 1.0f; // gating is ideally an identity function + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values - const float limit = p.limit; - dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp - gate_elt = min(max(-limit, gate_elt), limit) + 1; + const float limit = p.limit; + dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 
1.0f : 0.0f; // Derivative of clamp + gate_elt = min(max(-limit, gate_elt), limit) + 1; } if constexpr (IS_DGATED) { @@ -184,31 +185,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = act_elt; float act_x; float dact_x; - if constexpr(std::is_same::value){ + if constexpr (std::is_same::value) { const float limit = p.limit; const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if(x <= limit){ + if (x <= limit) { dact_x = s + s * (1 - s) * 1.702 * x; + } else { + dact_x = 0.0f; } - else{ - dact_x = 0.0f; - } - } - else{ + } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); act_x = x * s; dact_x = x * s * (1 - s) + s; - } - else { + } else { act_x = ActOP(x, p); dact_x = DActOP(x, p); } } - float after_dact = dact_x * grad_elt * gate_elt; float after_dgate = act_x * grad_elt * dgate_elt; @@ -321,8 +318,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, - const ParamOP p) { + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -498,37 +494,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); float after_act_elt; float after_gate_elt; - float dgate_elt = 1.0f; // gating is ideally an identity function - if constexpr(std::is_same::value){ + float dgate_elt = 1.0f; // gating is ideally an identity function + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values - const float limit = p.limit; - dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp - gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; + const float limit = p.limit; + dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp + gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; float dact_x; - if constexpr(std::is_same::value){ + if constexpr (std::is_same::value) { const float limit = p.limit; const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if(x <= limit){ + if (x <= limit) { dact_x = s + s * (1 - s) * 1.702 * x; + } else { + dact_x = 0.0f; } - else{ - dact_x = 0.0f; - } - } - else{ + } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); act_x = x * s; dact_x = x * s * (1 - s) + s; - } - else { + } else { act_x = ActOP(x, p); dact_x = DActOP(x, p); } @@ -763,36 +756,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float after_act_elt; float after_gate_elt; float dgate_elt = 1.0f; - if constexpr(std::is_same::value){ + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values - const float limit = p.limit; - dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp - gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; + const float limit = p.limit; + dgate_elt = + gate_elt <= limit && gate_elt >= -limit ? 
1.0f : 0.0f; // Derivative of clamp + gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; float dact_x; - if constexpr(std::is_same::value){ + if constexpr (std::is_same::value) { const float limit = p.limit; const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if(x <= limit){ + if (x <= limit) { dact_x = s + s * (1 - s) * 1.702 * x; + } else { + dact_x = 0.0f; } - else{ - dact_x = 0.0f; - } - } - else{ + } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); act_x = x * s; dact_x = x * s * (1 - s) + s; - } - else { + } else { act_x = ActOP(x, {}); dact_x = DActOP(x, {}); } @@ -1021,8 +1012,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols, p); + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) } @@ -1238,7 +1228,8 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t str template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, + cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); CheckOutputTensor(*output, "dgated_act_output"); @@ -1341,8 +1332,8 @@ namespace detail { template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, ParamOP p, - cudaStream_t stream) { +void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, + ParamOP p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? 
*(convertNVTETensorCheck(grad)) : grad_empty_tensor; @@ -1355,7 +1346,8 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, } else { if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, stream); + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, + stream); } else { cast_gated(gated_input_tensor, output_tensor, p, stream); } diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 22a480d13b..73bb33277f 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -11,7 +11,7 @@ namespace transformer_engine { struct Empty {}; -struct GptOssParam{ +struct GptOssParam { float limit; }; @@ -70,8 +70,8 @@ __device__ inline OType clamp(const IType val, const float min_limit, const floa template __device__ inline OType oss_silu(const IType val, const GptOssParam& p) { const Empty e = {}; - const float cval = clamp(val, - -std::numeric_limits::infinity(), p.limit); // Clamping + const float cval = + clamp(val, -std::numeric_limits::infinity(), p.limit); // Clamping return qgelu(cval, e); } @@ -90,12 +90,12 @@ __device__ inline OType dclamp(const IType val, const float min_limit, const flo template __device__ inline OType oss_dsilu(const IType val, const GptOssParam& p) { const Empty e = {}; - const bool dclamp_val = dclamp(val, - -std::numeric_limits::infinity(), p.limit); - const float clamp_val = clamp(val, - -std::numeric_limits::infinity(), p.limit); + const bool dclamp_val = + dclamp(val, -std::numeric_limits::infinity(), p.limit); + const float clamp_val = + clamp(val, -std::numeric_limits::infinity(), p.limit); const float dsilu_val = dqgelu(clamp_val, e); - return dclamp_val ? dsilu_val: 0.0f; + return dclamp_val ? dsilu_val : 0.0f; } template diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index b3fc971e26..c9e54ecef8 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -433,7 +433,7 @@ __launch_bounds__(unary_kernel_threads) __global__ const ComputeType val = static_cast(loader0.separate()[i]); ComputeType val2 = static_cast(loader1.separate()[i]); - if constexpr(std::is_same::value){ + if constexpr (std::is_same::value) { // Clamp the gated value and add 1 at the end // https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 ComputeType limit = p.limit; @@ -542,11 +542,11 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType gate_in = static_cast(input_loader1.separate()[i]); ComputeType dgate_in = 1.0f; - if constexpr(std::is_same::value){ + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values - const ComputeType limit = p.limit; - dgate_in = gate_in <= limit && gate_in >= -limit ? 1.0f : 0.0f; // Derivative of clamp - gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; + const ComputeType limit = p.limit; + dgate_in = gate_in <= limit && gate_in >= -limit ? 
1.0f : 0.0f; // Derivative of clamp + gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 657e638b00..8495179831 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -199,7 +199,8 @@ py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle q py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit); -py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, float limit); +py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + float limit); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index f4d87e3f25..52611f5b03 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -8,7 +8,6 @@ #include "common.h" #include "pybind.h" - namespace transformer_engine::pytorch { template @@ -185,7 +184,7 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q return dactivation_helper(grad, input, quantizer); } -py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit){ +py::object gpt_oss_swiglu(const at::Tensor& input, py::handle quantizer, float limit) { init_extension(); // Input tensor auto input_tensor = input.contiguous(); @@ -203,28 +202,34 @@ py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float l if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly - NVTE_SCOPED_GIL_RELEASE( - { nvte_gptoss_swiglu(input_cpp.data(), out_cpp.data(),limit, at::cuda::getCurrentCUDAStream()); }); + NVTE_SCOPED_GIL_RELEASE({ + nvte_gptoss_swiglu(input_cpp.data(), out_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); + }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation in high-precision fused together with amax, then quantize. 
auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); + NVTE_SCOPED_GIL_RELEASE({ + nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, + at::cuda::getCurrentCUDAStream()); + }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { // Compute activation in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); }); + NVTE_SCOPED_GIL_RELEASE({ + nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, + at::cuda::getCurrentCUDAStream()); + }); quantizer_cpp->quantize(temp_cpp, out_cpp); } return out_py; } -py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &input, py::handle quantizer, float limit){ +py::object gpt_oss_dswiglu(const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer, float limit) { init_extension(); // Grad output and input tensors auto grad_output_tensor = grad_output.contiguous(); @@ -246,7 +251,7 @@ py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &inpu // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation backward in high-precision fused together with amax, then quantize. 
@@ -254,7 +259,7 @@ py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &inpu auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); } else { @@ -262,13 +267,12 @@ py::object gpt_oss_dswiglu(const at::Tensor &grad_output, const at::Tensor &inpu auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); } return grad_input_py; - } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d4245c823f..b81608b1b2 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -136,8 +136,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); - m.def("gpt_oss_swiglu", transformer_engine::pytorch::gpt_oss_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), - py::arg("quantizer"), py::arg("limit") = 7.0f); + m.def("gpt_oss_swiglu", transformer_engine::pytorch::gpt_oss_swiglu, + "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), + py::arg("limit") = 7.0f); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -161,8 +162,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("gpt_oss_dswiglu", transformer_engine::pytorch::gpt_oss_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer"), py::arg("limit") = 7.0f); + m.def("gpt_oss_dswiglu", transformer_engine::pytorch::gpt_oss_dswiglu, + "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer"), py::arg("limit") = 7.0f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 505a9ea785..6dfdf3cac6 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,7 +4,19 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU, GptOssSwiglu +from .activation import ( + GELU, + GEGLU, + QGELU, + QGEGLU, + ReLU, + ReGLU, + SReLU, + SReGLU, + SiLU, + SwiGLU, + GptOssSwiglu, +) from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py 
b/transformer_engine/pytorch/ops/basic/activation.py index ab72eaf9c1..7dfd655e18 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -27,7 +27,7 @@ "SReGLU", "SiLU", "SwiGLU", - "GptOssSwiglu" + "GptOssSwiglu", ] @@ -426,4 +426,4 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.gpt_oss_swiglu(*args, limit=self.limit, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.gpt_oss_dswiglu(*args, limit=self.limit, **kwargs) \ No newline at end of file + return tex.gpt_oss_dswiglu(*args, limit=self.limit, **kwargs) From 5d06c2acf4e866274223ae0c7039731147969cae Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 05:22:01 +0000 Subject: [PATCH 06/53] fix the merge conflict MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] fix cross entropy vanishing gradients 
(#2139) * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Signed-off-by: Varun Thumbe Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: Varun Thumbe [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. 
Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe Update list of authorized CI users (#2152) Signed-off-by: Tim Moon Signed-off-by: Varun Thumbe a bit of cleanup Signed-off-by: Varun Thumbe --- .github/workflows/trigger-ci.yml | 1 + setup.py | 7 +- tests/cpp/CMakeLists.txt | 1 + tests/jax/multi_process_launch.sh | 6 +- tests/jax/test_custom_call_compute.py | 16 +- tests/jax/test_distributed_layernorm_mlp.py | 17 +- ..._multi_process_distributed_grouped_gemm.py | 16 +- tests/pytorch/test_fusible_ops.py | 7 +- tests/pytorch/test_parallel_cross_entropy.py | 59 +++-- transformer_engine/common/__init__.py | 5 + .../common/activation/activation_template.h | 4 +- transformer_engine/common/activation/gelu.cu | 12 +- transformer_engine/common/activation/relu.cu | 12 +- .../common/activation/swiglu.cu | 16 +- .../include/transformer_engine/activation.h | 6 +- .../include/transformer_engine/recipe.h | 15 ++ .../common/recipe/current_scaling.cu | 66 ++++- .../common/transpose/cast_transpose.h | 5 +- .../quantize_transpose_square_blockwise.cu | 20 +- .../quantize_transpose_vector_blockwise.cu | 18 +- .../common/util/cast_gated_kernels.cuh | 25 +- .../common/util/cast_kernels.cuh | 11 +- transformer_engine/common/util/math.h | 21 +- .../common/util/vectorized_pointwise.h | 6 +- transformer_engine/jax/activation.py | 12 +- .../jax/cpp_extensions/activation.py | 49 ++-- transformer_engine/jax/cpp_extensions/gemm.py | 16 +- .../jax/cpp_extensions/normalization.py | 34 +-- .../jax/cpp_extensions/quantization.py | 70 +++--- transformer_engine/jax/dense.py | 41 ++- transformer_engine/jax/layernorm.py | 5 +- transformer_engine/jax/layernorm_dense.py | 12 +- transformer_engine/jax/layernorm_mlp.py | 23 +- transformer_engine/jax/quantize/quantizer.py | 42 +++- .../jax/quantize/scaling_modes.py | 86 ++++++- transformer_engine/jax/quantize/tensor.py | 124 ++++++--- .../pytorch/csrc/extensions/activation.cpp | 237 +++++------------- transformer_engine/pytorch/csrc/quantizer.cpp | 3 +- transformer_engine/pytorch/module/base.py | 3 +- .../pytorch/module/grouped_linear.py | 6 +- .../pytorch/module/layernorm_mlp.py | 7 +- .../pytorch/triton/cross_entropy.py | 3 + 42 files changed, 660 insertions(+), 485 deletions(-) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 85a81a6d48..f12a95d79a 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -57,6 +57,7 @@ jobs: || github.actor == 'tdophung' || github.actor == 'vthumbe1503' || github.actor == 'janekb04' + || github.actor == 'shengfangd' ) steps: - name: Check if comment is issued by authorized person diff --git a/setup.py b/setup.py index 52adaf9238..ed1f5b8a9d 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ from build_tools.te_version import te_version from build_tools.utils import ( cuda_archs, + cuda_version, get_frameworks, remove_dups, ) @@ -70,11 +71,11 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( - "nvidia-cublasmp-cu12" - ).locate_file("nvidia/cublasmp/cu12") + f"nvidia-cublasmp-cu{cuda_version()[0]}" + ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") 
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( - "nvidia-nvshmem-cu12" + f"nvidia-nvshmem-cu{cuda_version()[0]}" ).locate_file("nvidia/nvshmem") cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") print("CMAKE_FLAGS:", cmake_flags[-2:]) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index c2c9d0d915..412c5d34d9 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,5 +43,6 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh index 3e0852f393..fcb066de75 100644 --- a/tests/jax/multi_process_launch.sh +++ b/tests/jax/multi_process_launch.sh @@ -12,12 +12,12 @@ XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true export XLA_FLAGS="${XLA_BASE_FLAGS}" -NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader) +NUM_RUNS=$(nvidia-smi -L | wc -l) for ((i=1; i /dev/null 2>&1 & + CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 & done -CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS wait diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index d5f21651db..11f07d9133 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( + NoScaleTensor, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -182,7 +183,7 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type) + return _jax_act_lu(x, activation_type).data def value_n_grad_ref_func(self, x, activation_type): jitted_reference = jit( @@ -337,8 +338,8 @@ def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantize ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - # if isinstance(ln_out, ScaledTensor): - # ln_out = ln_out.dequantize() + # This is a no-op for non-quantized data + ln_out = ln_out.dequantize() return ln_out key = jax.random.PRNGKey(0) @@ -765,7 +766,9 @@ def _test_quantize_dact_dbias( te_output, jax_output, precise_comparison=precise_comparison ) else: - assert_allclose(te_output, jax_output) + assert isinstance(te_output, NoScaleTensor) + assert isinstance(jax_output, NoScaleTensor) + assert_allclose(te_output.data, jax_output.data) if is_dbias: # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. 
@@ -1020,8 +1023,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - if isinstance(ln_out, ScaledTensor): - ln_out = ln_out.dequantize() + ln_out = ln_out.dequantize() return ln_out @@ -1177,7 +1179,7 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = _jax_act_lu(linear_1_out, activation_type) + x = _jax_act_lu(linear_1_out, activation_type).data linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 90b762c240..a44921c641 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -173,7 +173,9 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + ): single_jitter = jax.jit( value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), @@ -184,7 +186,7 @@ def _test_layernorm_mlp_grad( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) @@ -226,7 +228,12 @@ def _test_layernorm_mlp_grad( fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + + if fwd_test_type == jnp.float16 and use_bias: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) + else: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): @@ -252,7 +259,7 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -281,7 +288,7 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 
6fce62d8cc..31209d1bc9 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +import jax.experimental.multihost_utils as jem from transformer_engine.jax.dense import grouped_dense as te_grouped_dense from transformer_engine.jax.quantize import ( @@ -13,7 +14,7 @@ ScalingMode, ) -from utils import assert_allclose +from utils import assert_allclose, dtype_tols N_GROUP = 8 @@ -137,9 +138,16 @@ def run(x, w): out, dx, dw = test_func_jitted(x, w, w_amax) ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) - assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2) - assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2) + e4m3_tols = dtype_tols(jnp.float8_e4m3fn) + e5m2_tols = dtype_tols(jnp.float8_e5m2) + + out, ref_out = jem.process_allgather((out, ref_out)) + dx, ref_dx = jem.process_allgather((dx, ref_dx)) + dw, ref_dw = jem.process_allgather((dw, ref_dw)) + + jnp.allclose(out, ref_out, **e4m3_tols) + jnp.allclose(dx, ref_dx, **e5m2_tols) + jnp.allclose(dw, ref_dw, **e5m2_tols) if __name__ == "__main__": diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 5674f842d9..93d0c3dc53 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1721,7 +1721,6 @@ def test_gpt_oss_swiglu( quantize_forward: bool, quantize_backward: bool, ): - print(_quantization_list) # Tensor dimensions in_shape = list(out_shape) in_shape[-1] *= 2 @@ -1747,8 +1746,8 @@ def test_gpt_oss_swiglu( # Plain PyTorch implementation x_glu, x_linear = x_ref.chunk(2, dim=-1) - x_glu = x_glu.clamp(min=None, max=7.0) - x_linear = x_linear.clamp(min=-7.0, max=7.0) + x_glu = x_glu.clamp(min=None, max=0.1) + x_linear = x_linear.clamp(min=-0.1, max=0.1) out_glu = x_glu * torch.sigmoid(1.702 * x_glu) y_ref = out_glu * (x_linear + 1) y_ref.backward(dy_ref) @@ -1758,7 +1757,7 @@ def test_gpt_oss_swiglu( forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.GptOssSwiglu(limit=7.0), + te_ops.GptOssSwiglu(limit=0.1), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 77bea2b360..fa56852ffc 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -6,6 +6,8 @@ import torch from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from utils import dtype_tols + class TestParallelCrossEntropy: @@ -18,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float): label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" ) - def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): - + def generate_input( + self, + dtype: torch.dtype, + swap_dim: bool, + ignore_idx: bool, + device: torch.device = "cuda", + ): SQ = random.choice([64, 128]) batch = random.choice([1, 2]) vocab = random.choice([64000, 128000]) ignore = random.sample(range(0, SQ - 1), 5) + # Generate random data if swap_dim: - self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() + self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (SQ, 
batch), device=device) else: - self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() + self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device) if ignore_idx: for i in ignore: @@ -40,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): else: self.tar_test[0][i] = -100 + # Make copy of data for reference implementation self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) + # Enable autograd + self.input_test.requires_grad_() + self.input_ref.requires_grad_() + def one_iteration_test( self, dtype: torch.dtype, @@ -52,18 +65,20 @@ def one_iteration_test( ignore_idx: bool = False, ): + # Random data self.generate_input(dtype, swap_dim, ignore_idx) - self.input_test.requires_grad_(True) - self.input_ref.requires_grad_(True) - + # Forward pass test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) - ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) - # Handle backward pass based on the test scenario + # Compute square to avoid trivial backward pass + test_loss = torch.square(test_loss) + ref_loss = torch.square(ref_loss) + + # Backward pass if reduce_loss: test_loss.backward() ref_loss.backward() @@ -71,16 +86,18 @@ def one_iteration_test( test_loss.sum().backward() ref_loss.sum().backward() - test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss - - if ignore_idx: - print(test_loss, ref_loss) - - # Compare gradients when backward pass was called - torch.testing.assert_close( - torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad - ) - + # Check that loss and grad input match + tols = dtype_tols(dtype) + test_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = ref_loss.reshape(test_loss.size()) + test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = ref_grad_input.reshape(test_grad_input.size()) + torch.testing.assert_close(test_loss, ref_loss, **tols) + torch.testing.assert_close(test_grad_input, ref_grad_input, **tols) + + # Reset data self.input_test = None self.input_ref = None self.tar_test = None diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 7feb5fda5f..dd1ec480b2 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str: except ModuleNotFoundError: return "" + # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" + # above doesn't through. However, they don't set "__file__" attribute. 
+ if nvidia.__file__ is None: + return "" + include_dir = Path(nvidia.__file__).parent / "cuda_runtime" return str(include_dir) if include_dir.exists() else "" diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 3f701b1560..78b90c2e93 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -51,7 +51,7 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, } template -void gated_act_fn(const NVTETensor input, NVTETensor output, Param p, cudaStream_t stream) { +void gated_act_fn(const NVTETensor input, NVTETensor output, Param& p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = false; constexpr NVTETensor grad = nullptr; @@ -60,7 +60,7 @@ void gated_act_fn(const NVTETensor input, NVTETensor output, Param p, cudaStream template -void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param p, +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param& p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = true; diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 27ffd3f4e5..9a5cff7fa2 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(input, output, {}, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>(grad, input, output, {}, stream); + Empty e = {}; + dgated_act_fn, dgelu>(grad, input, output, e, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,13 +51,15 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(input, output, {}, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>(grad, input, output, {}, + Empty e = {}; + dgated_act_fn, dqgelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index 8e598d5ea5..be38e187e8 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(input, output, {}, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t 
stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>(grad, input, output, {}, stream); + Empty e = {}; + dgated_act_fn, drelu>(grad, input, output, e, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,13 +51,15 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(input, output, {}, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>(grad, input, output, {}, + Empty e = {}; + dgated_act_fn, dsrelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 30042392a6..3b0738b559 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -23,27 +23,33 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(input, output, {}, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>(grad, input, output, {}, stream); + Empty e = {}; + dgated_act_fn, dsilu>(grad, input, output, e, stream); } -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, - cudaStream_t stream) { +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, const float* const args, + int args_size, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_swiglu); + NVTE_CHECK(args_size==1); + const float limit = *args; using namespace transformer_engine; GptOssParam param = {limit}; gated_act_fn>(input, output, param, stream); } void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, cudaStream_t stream) { + const float* const args, int args_size, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_dswiglu); + NVTE_CHECK(args_size==1); + const float limit = *args; using namespace transformer_engine; GptOssParam param = {limit}; dgated_act_fn, oss_dsilu>(grad, input, output, diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 10aeba8ede..f921851ab9 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -186,8 +186,8 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) /* TODO: Add documentation once the API finalizes. 
*/ -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, - cudaStream_t stream); +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, const float* const args, + int args_size, cudaStream_t stream); void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); @@ -251,7 +251,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp TODO: Add documentation once the API finalizes. */ void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, cudaStream_t stream); + const float* const args, int args_size, cudaStream_t stream); void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 50fb696ea6..2fc8c1095c 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( */ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Compute an FP8 tensor's amax with quantization config. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor, using the provided + * quantization configuration. + * One useful config is the noop tensor, which is needed by cuda graph. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig config, cudaStream_t stream); + /*! \brief Update an FP8 tensor's scale based on its amax. * * This is only supported for FP8 tensors with per-tensor scaling. 
diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index e1657b77a1..fd907efcba 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512; template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, const size_t N, - const size_t num_aligned_elements) { + const size_t num_aligned_elements, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + VectorizedLoader loader(input, N); InputType max = 0.f; const int warp_id = threadIdx.x / THREADS_PER_WARP; @@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__ } template -void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr, + cudaStream_t stream) { // Zero out amax so we can update with atomic max NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); @@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud switch (align) { case Alignment::SAME_ALIGNED: amax_kernel - <<>>(input, amax, N, num_aligned_elements); + <<>>(input, amax, N, num_aligned_elements, noop_ptr); break; case Alignment::SAME_UNALIGNED: amax_kernel - <<>>(input, amax, N, num_aligned_elements); + <<>>(input, amax, N, num_aligned_elements, noop_ptr); break; case Alignment::DIFFERENT: { // This case is a logic error, since there is only one pointer (input) // in the alignment check. Still safe to process without vectorization. - amax_kernel<1, true, InputType><<>>(input, amax, N, N); + amax_kernel<1, true, InputType> + <<>>(input, amax, N, N, noop_ptr); break; } } @@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud } // namespace } // namespace transformer_engine -void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { - NVTE_API_CALL(nvte_compute_amax); +namespace { + +void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream, + const NVTEQuantizationConfig config_) { using namespace transformer_engine; // Check input tensor @@ -138,12 +146,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt to_string(output.amax.dtype), ")"); CheckOutputTensor(output, "output_compute_amax", true); + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? 
convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Compute amax TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel(reinterpret_cast(input.data.dptr), reinterpret_cast(output.amax.dptr), input.data.numel(), - stream);); // NOLINT(*) + noop_ptr, stream);); // NOLINT(*) +} + +} // anonymous namespace + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + compute_amax_impl(input_, output_, stream, nullptr); +} + +void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_, + const NVTEQuantizationConfig config_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax_with_config); + compute_amax_impl(input_, output_, stream, config_); } namespace transformer_engine { @@ -151,7 +182,11 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, - const float epsilon) { + const float epsilon, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); } @@ -197,10 +232,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, max_fp8 = Quantized_Limits::max_norm;); + // noop tensor for cuda graph + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? 
convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Update scale compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(output.amax.dptr), reinterpret_cast(output.scale.dptr), max_fp8, config.force_pow_2_scales, - config.amax_epsilon); + config.amax_epsilon, noop_ptr); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a737239260..abfa226e88 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + const SimpleTensor &noop_tensor, cudaStream_t stream); // enum class for rowwise usage enum class FP8BlockwiseRowwiseOption { @@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scale, cudaStream_t stream); + const bool pow_2_scale, const SimpleTensor &noop_tensor, + cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index a603d1f1a2..c3f085b877 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, const __grid_constant__ CUtensorMap tensor_map_output_t, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream) { + const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_square_blockwise); 
checkCuDriverContext(stream); @@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + if (return_transpose) { NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); @@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - tensor_map_output_trans, pow_2_scale); + tensor_map_output_trans, pow_2_scale, noop_ptr); } else { block_scaled_cast_transpose_kernel_notaligned @@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + pow_2_scale, noop_ptr); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 6f5c0f3a6c..4c82b8c81b 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scaling) { + const bool pow_2_scaling, const float* noop_ptr) { + // skip execution if noop + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; bool return_columnwise_gemm_ready = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow2_scale, cudaStream_t stream) { + const bool pow2_scale, const SimpleTensor& noop_tensor, + cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; @@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -613,9 +621,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, - columnwise_option, pow2_scale);) // kAligned - ) // OutputType - ) // InputType + columnwise_option, pow2_scale, noop_ptr);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index fe1a522a96..ce74761a22 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -171,11 +171,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - float dgate_elt = 1.0f; // gating is ideally an identity function + bool dgate_elt = true; // gating is ideally an identity function if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values const float limit = p.limit; - dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp + dgate_elt = gate_elt < limit && gate_elt > -limit; // Derivative of clamp gate_elt = min(max(-limit, gate_elt), limit) + 1; } @@ -190,7 +190,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if (x <= limit) { + if (x < limit) { dact_x = s + s * (1 - s) * 1.702 * x; } else { dact_x = 0.0f; @@ -207,7 +207,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt * dgate_elt; + float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); @@ -494,11 +494,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); float after_act_elt; float after_gate_elt; - float dgate_elt = 1.0f; // gating is ideally an identity function + bool dgate_elt = true; // gating is ideally an identity function if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values const float limit = p.limit; - dgate_elt = gate_elt <= limit && gate_elt >= -limit ? 
1.0f : 0.0f; // Derivative of clamp + dgate_elt = gate_elt < limit && gate_elt > -limit; // Derivative of clamp gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; } if constexpr (IS_DGATED) { @@ -511,7 +511,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if (x <= limit) { + if (x < limit) { dact_x = s + s * (1 - s) * 1.702 * x; } else { dact_x = 0.0f; @@ -528,7 +528,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt * dgate_elt; + after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f; } else { after_act_elt = ActOP(act_elt, p) * gate_elt; } @@ -755,12 +755,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate.data.elt[e]); float after_act_elt; float after_gate_elt; - float dgate_elt = 1.0f; + float dgate_elt = true; if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values const float limit = p.limit; - dgate_elt = - gate_elt <= limit && gate_elt >= -limit ? 1.0f : 0.0f; // Derivative of clamp + dgate_elt = gate_elt < limit && gate_elt > -limit; // Derivative of clamp gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; } if constexpr (IS_DGATED) { @@ -773,7 +772,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if (x <= limit) { + if (x < limit) { dact_x = s + s * (1 - s) * 1.702 * x; } else { dact_x = 0.0f; @@ -790,7 +789,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt * dgate_elt; + after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f; after_act_rowwise[j] = after_act_elt; after_gate_rowwise[j] = after_gate_elt; } else { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 1158132e3f..8d87351181 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1427,7 +1427,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor.data, stream); break; } case NVTE_BLOCK_SCALING_1D: { @@ -1455,10 +1456,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; } - quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, stream); + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor.data, stream); break; } default: diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 73bb33277f..e3843b0a28 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -61,17 +61,10 @@ __device__ inline OType silu(const IType val, const Empty& e) { return cval * sigmoid(cval, e); } -template -__device__ inline OType clamp(const IType val, const float min_limit, const float max_limit) { - const float cval = val; - return max(min(cval, max_limit), min_limit); -} - template __device__ inline OType oss_silu(const IType val, const GptOssParam& p) { const Empty e = {}; - const float cval = - clamp(val, -std::numeric_limits::infinity(), p.limit); // Clamping + const float cval = min(p.limit, (float)val); // Clamping return qgelu(cval, e); } @@ -81,19 +74,11 @@ __device__ inline OType dsilu(const IType val, const Empty& e) { return cval * dsigmoid(cval, e) + sigmoid(cval, e); } -template -__device__ inline OType dclamp(const IType val, const float min_limit, const float max_limit) { - const float cval = val; - return cval <= max_limit && cval >= min_limit; -} - template __device__ inline OType oss_dsilu(const IType val, const GptOssParam& p) { const Empty e = {}; - const bool dclamp_val = - dclamp(val, -std::numeric_limits::infinity(), p.limit); - const float clamp_val = - clamp(val, -std::numeric_limits::infinity(), p.limit); + const bool dclamp_val = (float)val <= p.limit; + const float clamp_val = min((float)val, p.limit); const float dsilu_val = dqgelu(clamp_val, e); return dclamp_val ? dsilu_val : 0.0f; } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index c9e54ecef8..270f0375f0 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -540,17 +540,17 @@ __launch_bounds__(unary_kernel_threads) __global__ const ComputeType grad_val = static_cast(grad_loader.separate()[i]); const ComputeType gelu_in = static_cast(input_loader0.separate()[i]); ComputeType gate_in = static_cast(input_loader1.separate()[i]); - ComputeType dgate_in = 1.0f; + bool dgate_in = true; if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values const ComputeType limit = p.limit; - dgate_in = gate_in <= limit && gate_in >= -limit ? 1.0f : 0.0f; // Derivative of clamp + dgate_in = gate_in < limit && gate_in > -limit; // Derivative of clamp gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; - ComputeType after_dgate = grad_val * Activation(gelu_in, p) * dgate_in; + ComputeType after_dgate = dgate_in ? 
grad_val * Activation(gelu_in, p) : 0.0f; if (requires_amax) { __builtin_assume(max >= 0); diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index ef6def2d03..12b35ec43c 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -14,7 +14,7 @@ from . import cpp_extensions as tex -from .quantize.tensor import ScaledTensor +from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer @@ -22,7 +22,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, -) -> Union[jnp.ndarray, ScaledTensor]: +) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. This function applies a sequence of activation functions to the input tensor. @@ -72,8 +72,8 @@ def _activation_fwd_rule(x, activation_type, quantizer): Tuple of (output, context) for backward pass """ fwd_output = tex.act_lu(x, activation_type, quantizer) - if isinstance(fwd_output, ScaledTensor): - fwd_output = fwd_output.dequantize() + # This is a no-op for higher-precision tensors + fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) @@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g): (x, _) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) + # No quantization is used in this VJP backward, so the output should + # always be a NoScaleTensor + assert isinstance(dx, NoScaleTensor) + dx = dx.data return (dx, None) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index fe2253598f..d3c7d2b086 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -29,7 +29,7 @@ ) from .quantization import _jax_dbias, _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: +def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ @@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S x = jnp.squeeze(x, axis=-2) if quantizer: return quantizer.quantize(x, flatten_axis=-1) - return x + return NoScaleTensor(data=x, amax=None) def _jax_quantize_dact_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, @@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias( _, vjp_func = jax.vjp( partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) ) - (dx,) = vjp_func(dz.astype(jnp.float32)) + # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. 
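+        # A short note on the assumed jax.vjp contract: the cotangent passed to vjp_func must
+        # mirror the pytree structure of the primal output, and _jax_act_lu with quantizer=None
+        # returns a NoScaleTensor, so dz is wrapped into NoScaleTensor(data=..., amax=None)
+        # before it is fed back through the VJP below.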
+ dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) + (dx,) = vjp_func(dz) dbias = None if is_dbias: @@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias( dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) else: dx = dx.astype(x.dtype) + dx = NoScaleTensor(data=dx, amax=None) return dx, dbias @@ -981,7 +984,6 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -990,7 +992,6 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1035,10 +1036,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) - if noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype - ) + out = NoScaleTensor( + data=out, + amax=None, + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1092,7 +1093,6 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1103,7 +1103,6 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1146,19 +1145,10 @@ def quantize_dact_dbias( if is_dbias: dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - if noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - dbias, - ) - + output = NoScaleTensor( + data=output, + amax=None, + ) return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet @@ -1167,7 +1157,7 @@ def quantize_dact_dbias( dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None ) return _quantize_dbias_impl( - out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 ) is_gated = act_len == 2 @@ -1194,7 +1184,7 @@ def quantize_dact_dbias( quantizer=None, ) out, dbias = _quantize_dbias_impl( - out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 ) return out, dbias @@ -1258,7 +1248,6 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1268,7 +1257,6 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. 
quantizer: Optional quantizer for FP8 quantization of the output gradient. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1279,6 +1267,5 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, - noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index be73f708e2..acc8d67274 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -22,6 +22,8 @@ from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize from ..quantize import ( + AbstractBaseTensor, + NoScaleTensor, ScaledTensor, ScaledTensor2x, GroupedScaledTensor1x, @@ -228,6 +230,11 @@ def _dims_are_consecutive(dims): "require non-transposed LHS and transposed RHS operands " "(`contracting_dims=((-1, ), (-1, ))`)." ) + else: + assert lhs.dtype == rhs.dtype, ( + "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." + f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" + ) # Determine output shape and dtype assert ( @@ -1134,8 +1141,8 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( - lhs: Union[jnp.ndarray, ScaledTensor], - rhs: Union[jnp.ndarray, ScaledTensor], + lhs: Union[jnp.ndarray, AbstractBaseTensor], + rhs: Union[jnp.ndarray, AbstractBaseTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, @@ -1191,6 +1198,11 @@ def gemm( compute the GeLU contribution to the gradient. Only supported with TE's custom call to cuBLAS GEMM. 
""" + if isinstance(lhs, NoScaleTensor): + lhs = lhs.data + if isinstance(rhs, NoScaleTensor): + rhs = rhs.data + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility if lhs_quantizer is None or rhs_quantizer is None: quantizer_set = kwargs.get("quantizer_set", None) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 7296afc725..de1877de5c 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -30,7 +30,7 @@ ) from .quantization import _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -845,6 +845,7 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1) @@ -869,6 +870,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(rsigma, axis=-1) @@ -930,7 +932,7 @@ def layernorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, mu, rsigma + return NoScaleTensor(data=output, amax=None), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING @@ -1064,7 +1066,7 @@ def layernorm_bwd( ) mu_empty = jnp.zeros(mu.shape, mu.dtype) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, mu_empty, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), mu_empty, rsigma_empty)) return NormBwdPrimitive.outer_primitive.bind( dz, x, @@ -1133,14 +1135,14 @@ def rmsnorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, rsigma + return NoScaleTensor(data=output, amax=None), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out, quantizer) + out, _ = _quantize_dbias_impl(out.data, quantizer) return out, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1152,7 +1154,9 @@ def rmsnorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + ) return out, rsigma is_2x2x = quantizer.is_2x2x() @@ -1254,7 +1258,7 @@ def rmsnorm_bwd( gamma, ) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty)) mu = jnp.empty(()) dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind( dz, @@ -1276,7 +1280,6 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], - noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. 
@@ -1293,7 +1296,6 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: A tuple containing: @@ -1321,20 +1323,6 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") - if quantizer is None and noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - scaling_mode=ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - mu, - rsigma, - ) - return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 198beb55eb..1813734b5e 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -4,7 +4,7 @@ """JAX/TE custom ops for quantization""" import operator from functools import reduce -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import math from packaging import version @@ -38,6 +38,7 @@ QuantizeLayout, ScalingMode, compute_scale_from_amax, + NoScaleTensor, ) if version.parse(jax.__version__) >= version.parse("0.5.0"): @@ -64,7 +65,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): 7, 8, 9, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer inner_primitive = None outer_primitive = None @@ -535,11 +536,15 @@ def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) -def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): +def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1): + if isinstance(dx, NoScaleTensor): + dx = dx.data sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis assert sum_axis < dx.ndim, "Flatten axis out of bounds!" 
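    # Worked example of the index arithmetic above (hypothetical shapes): dx.ndim == 4 with
    # flatten_axis == -2 gives sum_axis == 4 + (-2) == 2, which satisfies the bounds check.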
dtype = dtype or dx.dtype @@ -558,7 +563,9 @@ def _jax_quantize_dbias( flatten_axis: int = -1, ): if quantizer is None: - return x, None + if isinstance(x, NoScaleTensor): + return x, None + return NoScaleTensor(data=x, amax=None), None return ( quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis), _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis), @@ -566,12 +573,11 @@ def _jax_quantize_dbias( def _quantize_dbias_impl( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -581,28 +587,15 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + # Early-exit for non-quantized call - dq_dtype = dq_dtype or x.dtype + dq_dtype = dq_dtype or x.data.dtype if quantizer is None: dbias = None if is_dbias: - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - if noop_scaled_tensor: - # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() - # always works. - return ( - ScaledTensorFactory.create_2x( - x, - None, - x, - None, - scaling_mode=ScalingMode.NO_SCALING, - dq_dtype=x.dtype, - data_layout="NN", - flatten_axis=flatten_axis, - ), - dbias, - ) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return x, dbias # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, @@ -630,21 +623,25 @@ def _quantize_dbias_impl( dq_dtype=dq_dtype, flatten_axis=flatten_axis, ) - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias scale = jnp.empty((), jnp.float32) + amax = None if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). - amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) + amax = x.amax + if amax is None: + amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # Make sure amax is init with zero - amax = jnp.zeros((1,), jnp.float32) + if amax is None: + amax = jnp.zeros((1,), jnp.float32) # It is faster to use 1x quantization for tensor scaling is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) @@ -665,7 +662,7 @@ def _quantize_dbias_impl( updated_amax, dbias, ) = PrimitiveClass.outer_primitive.bind( - x, + x.data, scale, amax, out_dtype=quantizer.q_dtype, @@ -706,10 +703,9 @@ def _quantize_dbias_impl( def quantize( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -719,7 +715,6 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. 
- noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer is None. Returns: @@ -729,17 +724,15 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) return out def quantize_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -750,8 +743,6 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. - noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when - quantizer is None. Returns: A tuple containing: @@ -765,7 +756,6 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) @@ -968,7 +958,9 @@ def grouped_quantize( """ if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) # TODO(Phuong): add support for flatten_axis = -2 assert flatten_axis in ( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 65d65e7d4a..8087159a3a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -24,6 +24,7 @@ with_sharding_constraint_by_logical_axes, is_fp8_gemm_with_all_layouts_supported, TensorUsage, + get_quantize_config, ) @@ -80,23 +81,19 @@ def dense( Returns: Transformed output tensor """ - # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims=contracting_dims) - if bias is not None: - bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape - output += jnp.reshape(bias, bias_new_shape) - else: - output = _dense( - x, - kernel, - bias, - contracting_dims, - input_axes, - kernel_axes, - quantizer_set, - ) + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + + output = _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, + ) return output @@ -175,7 +172,9 @@ def _dense_fwd_rule( flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) casted_x = tex.quantize( - x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + x, + flatten_axis=flatten_axis_x, + quantizer=quantizer_set.x, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -183,7 +182,6 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -240,7 +238,6 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # GEMM NT @@ -445,7 +442,7 @@ def _grouped_dense_fwd_rule( ctx_kernel = ScaledTensorFactory.create_1x( global_ctx_kernel_data.reshape(-1), ctx_kernel.scale_inv, - ctx_kernel.scaling_mode, + scaling_mode=ctx_kernel.scaling_mode, dq_dtype=ctx_kernel.dq_dtype, is_colwise=False, data_layout="N", @@ -462,7 +459,7 @@ def _grouped_dense_fwd_rule( grouped_gemm_kernel = ScaledTensorFactory.create_1x( 
grouped_gemm_kernel_data.reshape(-1), ctx_kernel.scale_inv, - ctx_kernel.scaling_mode, + scaling_mode=ctx_kernel.scaling_mode, dq_dtype=ctx_kernel.dq_dtype, is_colwise=True, data_layout="T", diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 7a3ad597bf..0f5c6aeef6 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -17,7 +17,6 @@ from . import cpp_extensions as tex from .quantize import ( - ScaledTensor, Quantizer, ) @@ -112,8 +111,8 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, eps output, mu, rsigma = tex.normalization_fwd( x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer ) - if isinstance(output, ScaledTensor): - output = output.dequantize() + # This is a no-op for higher-precision tensors + output = output.dequantize() return output, (x, mu, rsigma, gamma, beta, quantizer) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index b830cdb4ff..fb97830759 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -22,6 +22,7 @@ noop_quantizer_set, with_sharding_constraint_by_logical_axes, TensorUsage, + get_quantize_config, ) @@ -68,6 +69,11 @@ def layernorm_dense( - The function supports automatic differentiation through JAX's custom VJP - Quantization is applied to both the normalized input and kernel """ + + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + output = _layernorm_dense( x, kernel, @@ -188,14 +194,15 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) 
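    # For example (hypothetical shapes): a kernel of shape (hidden_in, h1, h2) gives
    # flatten_axis == 1 - 3 == -2, so quantization keeps hidden_in as the leading dimension
    # of the flattened 2D view and flattens the output dimensions (h1, h2) together.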
flatten_axis = 1 - len(kernel.shape) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + kernel, + flatten_axis=flatten_axis, + quantizer=quantizer_set.kernel, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -278,7 +285,6 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 00e3ddc3e8..fc957801af 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -27,6 +27,7 @@ QuantizerSet, noop_quantizer_set, TensorUsage, + get_quantize_config, ) @@ -104,6 +105,11 @@ def layernorm_mlp( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel_1 = kernel_1.astype(input_dtype) + kernel_2 = kernel_2.astype(input_dtype) + output = _layernorm_mlp( x, gamma, @@ -266,12 +272,13 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + kernel_1, + flatten_axis=-2, + quantizer=ffn1_quantizer_set.kernel, ) # NN GEMM @@ -300,13 +307,16 @@ def _layernorm_mlp_fwd_rule( # (batch..., hidden_in) -> (batch..., hidden) casted_act_out = tex.act_lu( - dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + dot_1_output, + activation_type, + quantizer=ffn2_quantizer_set.x, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_kernel_2 = tex.quantize( - kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + kernel_2, + quantizer=ffn2_quantizer_set.kernel, ) # NN GEMM @@ -404,7 +414,9 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True + grad, + is_dbias=use_bias_2, + quantizer=ffn1_quantizer_set.dgrad, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -445,7 +457,6 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 6cecfa361f..306603bbe1 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -19,7 +19,13 @@ from transformer_engine.common import recipe from .scaling_modes import ScalingMode -from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory +from .tensor import ( + ScaledTensor, + ScaledTensor1x, + ScaledTensor2x, + ScaledTensorFactory, + NoScaleTensor, +) from .helper import ( get_quantize_config, get_quantize_config_class, @@ -217,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer): data_layout: str = 
"NT" def _quantize_func( - self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + self, + x: Union[jnp.ndarray, NoScaleTensor], + is_colwise=False, + dq_dtype=None, + flatten_axis=-1, ) -> ScaledTensor1x: """Quantize function helper for delayed scaling FP8. @@ -229,14 +239,17 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - amax = jnp.max(jnp.abs(x)).reshape((1,)) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) - scaled_x = x.astype(compute_dtype) * scale + scaled_x = x.data.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / scale @@ -263,7 +276,10 @@ def quantize( Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype if flatten_axis < 0: flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" @@ -347,11 +363,14 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - scaled_x = x.astype(compute_dtype) * self.scale + scaled_x = x.data.astype(compute_dtype) * self.scale # quantize() in the old dot.py do this way, leave this code block here for future debugging # compute_dtype = x.dtype @@ -360,7 +379,8 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale - self.update(jnp.max(jnp.abs(x)).reshape((1,))) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, scale_inv=scale_inv, @@ -460,6 +480,10 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> Returns: A ScaledTensor1x containing the quantized data """ + if isinstance(x, NoScaleTensor): + # No need for amax in MXFP8 block scaling, so simply extract the jnp.ndarray data tensor from the NoScaleTensor x. + x = x.data + # TODO(Phuong): use quantize_func from JAX if flatten_axis < 0: flatten_axis = x.ndim + flatten_axis diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 868570f73c..e81a614f0e 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -166,6 +166,90 @@ def get_shardy_sharding_rules( """ +class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): + """Implementation for no scaling mode. + + This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16. + """ + + def get_scale_dtype(self) -> jnp.dtype: + """Get the data type for scale tensors. 
This is a placeholder and won't be used for higher-precision values that don't have scaling. + + Returns: + The data type used for scale tensors (float32) + """ + return jnp.float32 + + def get_scale_shape( + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, + ) -> Tuple[int, ...]: + """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. + + Args: + data_shape: The shape of the tensor being scaled + is_colwise: Whether the scaling is column-wise + is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors - (1,) + """ + del data_shape, is_colwise, is_padded, flatten_axis + return (0,) + + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + return QuantizeLayout.ROWWISE + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for scale tensors in this mode. + + Args: + data_shape: Original shape of the data tensor + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors + """ + del data_shape, group_axis, is_colwise + assert isinstance(n_groups, int) + return (n_groups,) + + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + del flatten_axis + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + + class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for current scaling mode. @@ -740,5 +824,5 @@ def tree_unflatten(cls, aux_data, _children): ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), - ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 1459175b79..dbbac4abcc 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -25,6 +25,8 @@ __all__ = [ "TensorUsage", + "AbstractBaseTensor", + "NoScaleTensor", "ScaledTensor", "ScaledTensor1x", "ScaledTensor2x", @@ -34,14 +36,9 @@ ] -@register_pytree_node_class @dataclass -class ScaledTensor(ABC): - """Abstract base class for scaled tensors. - - This class defines the interface for all scaled tensor implementations, - providing methods for dequantization and accessing row/column-wise components. 
- """ +class AbstractBaseTensor(ABC): + """Abstract base class for all tensor types.""" @classmethod def tree_unflatten(cls, aux_data, children): @@ -93,9 +90,76 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st """ +@dataclass +class AbstractBaseTensor1x(AbstractBaseTensor): + """Abstract base class for single layout tensors.""" + + data: jnp.ndarray + amax: jnp.ndarray + + @register_pytree_node_class @dataclass -class ScaledTensor1x(ScaledTensor): +class NoScaleTensor(AbstractBaseTensor1x): + """Higher-precision tensor.""" + + def __post_init__(self): + assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray." + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations. + + Returns: + A tuple containing (children, aux_data) for tree operations + """ + children = (self.data, self.amax) + aux_data = () + return (children, aux_data) + + @property + def ndim(self): + """Number of dimensions of the underlying array.""" + return self.data.ndim + + def dequantize(self): + """This is a no-op for a higher-precision tensor so this simply returns the tensor's data.""" + return self.data + + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) + assert ( + q_layout == QuantizeLayout.ROWWISE + ), "Only ROWWISE layout is supported for NoScaleTensor" + return self + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names) + + return NoScaleTensor( + data=data, + amax=self.amax, + ) + + +class ScaledTensor(ABC): + """Abstract base class for scaled tensors.""" + + +@register_pytree_node_class +@dataclass +class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): """Single-scale quantized tensor implementation. This class represents a tensor quantized with a single scaling factor, @@ -113,9 +177,7 @@ class ScaledTensor1x(ScaledTensor): flatten_axis: The quantization axis for the tensor """ - data: jnp.ndarray scale_inv: jnp.ndarray - amax: jnp.ndarray scaling_mode: ScalingMode dq_dtype: jnp.dtype _dq_func: Callable @@ -154,7 +216,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax) + children = (self.data, self.amax, self.scale_inv) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -274,15 +336,15 @@ def __init__( self.original_shape = original_shape self.group_axis = group_axis super().__init__( - data, - scale_inv, - amax, - scaling_mode, - dq_dtype, - _dq_func, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) def __post_init__(self): @@ -339,7 +401,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st @register_pytree_node_class @dataclass -class ScaledTensor2x(ScaledTensor): +class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): """Double-scale quantized tensor implementation. 
This class represents a tensor quantized with both row-wise and column-wise scaling factors. @@ -503,15 +565,15 @@ def create_1x( flatten_axis = data.ndim - flatten_axis return ScaledTensor1x( - data, - scale_inv, - amax, - scaling_mode, - dq_dtype, - dequantizer.dequantize, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=dequantizer.dequantize, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) @staticmethod @@ -675,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . if isinstance(x, GroupedScaledTensor1x): raise NotImplementedError - if isinstance(x, ScaledTensor): + if isinstance(x, AbstractBaseTensor): return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 52611f5b03..f1d7f6733b 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,17 +1,18 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - #include "../extensions.h" #include "common.h" #include "pybind.h" namespace transformer_engine::pytorch { -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { +using FuncType = void (*)(const NVTETensor, NVTETensor, cudaStream_t); +using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float*, int, cudaStream_t); + +using DFuncType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); +using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, const float*, int, cudaStream_t); + +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, + const std::vector& args = {}) { init_extension(); // Input tensor @@ -30,31 +31,44 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + NVTE_SCOPED_GIL_RELEASE({ + if (!args.empty()) { + act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); + } + }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation in high-precision fused together with amax, then quantize. 
- auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + NVTE_SCOPED_GIL_RELEASE({ + if (!args.empty()) { + act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + } + }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { // Compute activation in high-precision, then quantize - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + NVTE_SCOPED_GIL_RELEASE({ + if (!args.empty()) { + act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + } + }); quantizer_cpp->quantize(temp_cpp, out_cpp); } return out_py; } -template +template py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer) { + py::handle quantizer, const std::vector& args = {}) { init_extension(); // Grad output and input tensors @@ -66,8 +80,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape_te = input_cpp.shape(); - const std::vector input_shape(input_shape_te.data, - input_shape_te.data + input_shape_te.ndim); + const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); @@ -76,24 +89,33 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); + if (!args.empty()) { + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), at::cuda::getCurrentCUDAStream()); + } }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation backward in high-precision fused together with amax, then quantize. 
auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); + if (!args.empty()) { + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); } else { // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); + if (!args.empty()) { + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); } @@ -101,178 +123,49 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i return grad_input_py; } -/* GELU and variants*/ +/* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -/* ReLU and variants*/ -py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } -py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - 
-py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -/* Silu and variants*/ +/* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } +/* gpt_oss functions */ py::object gpt_oss_swiglu(const at::Tensor& input, py::handle quantizer, float limit) { - init_extension(); - // Input tensor - auto input_tensor = input.contiguous(); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); - - // Construct output tensor - auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape = input_cpp.shape(); - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - output_shape.back() /= 2; - auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); - - // Compute activation - if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || - detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation directly - NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_swiglu(input_cpp.data(), out_cpp.data(), limit, at::cuda::getCurrentCUDAStream()); - }); - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation in high-precision fused together with amax, then quantize. 
- - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); - } else { - // Compute activation in high-precision, then quantize - - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_swiglu(input_cpp.data(), temp_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp->quantize(temp_cpp, out_cpp); - } - return out_py; + std::vector args = {limit}; + return activation_helper(input, quantizer, 2, args); } -py::object gpt_oss_dswiglu(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer, float limit) { - init_extension(); - // Grad output and input tensors - auto grad_output_tensor = grad_output.contiguous(); - auto input_tensor = input.contiguous(); - const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); - - // Construct grad input tensor - auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_cpp.shape(); - const std::vector input_shape(input_shape_te.data, - input_shape_te.data + input_shape_te.ndim); - auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); - - // Compute activation backward - if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || - detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation backward directly - NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); - }); - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation backward in high-precision fused together with amax, then quantize. 
- auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); - } else { - // Compute activation backward in high-precision, then quantize - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - nvte_gptoss_dswiglu(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), limit, - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp->quantize(temp_cpp, grad_input_cpp); - } - - return grad_input_py; +py::object gpt_oss_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, float limit) { + std::vector args = {limit}; + return dactivation_helper(grad, input, quantizer, args); } -} // namespace transformer_engine::pytorch +} // namespace transformer_engine::pytorch \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e04d424a36..cd7e70fecb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -497,7 +497,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { - NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); } // Perform amax reduction if needed diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6275abd19..0f2e3c4de1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1482,8 +1482,7 @@ def backward_dw(self): (wgrad, bgrad), _ = self.wgrad_store.pop() if not self.fuse_wgrad_accumulation: weight_tensor = noop_cat(self._get_weight_tensors()) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) + weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) if bias_tensor.grad is None: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3d7a5efaca..e9189ccc59 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -452,9 +452,6 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() @@ -829,8 +826,7 @@ def backward_dw(self): bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): - if weight_params[i].grad is None: - weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) + weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: for i in range(self.num_gemms): if bias_params[i].grad is None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 
182bf99f86..a6c55ceb79 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1197,7 +1197,6 @@ def fc1_wgrad_gemm( "with Userbuffers (tensor-parallel communication overlapping)" ) ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm) - fc1_wgrad = None if fuse_gemm_and_bias_fc1_wgrad: fc1_bias_grad = None else: @@ -2168,10 +2167,8 @@ def backward_dw(self): if self.fc1_bias.grad is None: self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype) if not self.fuse_wgrad_accumulation: - if self.fc2_weight.grad is None: - self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) - if self.fc1_weight.grad is None: - self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) + self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) + self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) del fc2_bias_grad_ del fc2_wgrad del fc1_wgrad diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 323a939223..7cfff1da9d 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -230,6 +230,7 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, + grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -252,6 +253,7 @@ def element_mul_kernel( X_ptr += program_id * X_stride # Load the gradient output value + grad_output_ptr += program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) # Perform the element-wise multiplication @@ -360,6 +362,7 @@ def cross_entropy_backward( _input, _input.stride(-2), grad_output, + 1 if grad_output.numel() > 1 else 0, V, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From 025ce6b372e15969c0558ad2c72720607f9f4239 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 22:21:57 +0000 Subject: [PATCH 07/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe --- .../common/activation/activation_template.h | 4 +- transformer_engine/common/activation/gelu.cu | 3 +- transformer_engine/common/activation/relu.cu | 3 +- .../common/activation/swiglu.cu | 4 +- .../pytorch/csrc/extensions/activation.cpp | 38 ++++++++++++------- 5 files changed, 31 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 78b90c2e93..1d9a3fb43c 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -51,7 +51,7 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, } template -void gated_act_fn(const NVTETensor input, NVTETensor output, Param& p, cudaStream_t stream) { +void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = false; constexpr NVTETensor grad = nullptr; @@ -60,7 +60,7 @@ void gated_act_fn(const NVTETensor input, NVTETensor output, Param& p, cudaStrea template -void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param& p, +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = true; diff --git a/transformer_engine/common/activation/gelu.cu 
b/transformer_engine/common/activation/gelu.cu index 9a5cff7fa2..4949ba5906 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -60,6 +60,5 @@ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; Empty e = {}; - dgated_act_fn, dqgelu>(grad, input, output, e, - stream); + dgated_act_fn, dqgelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index be38e187e8..c74fc6eee9 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -60,6 +60,5 @@ void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; Empty e = {}; - dgated_act_fn, dsrelu>(grad, input, output, e, - stream); + dgated_act_fn, dsrelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 3b0738b559..d249481602 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -38,7 +38,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, const float* const args, int args_size, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_swiglu); - NVTE_CHECK(args_size==1); + NVTE_CHECK(args_size == 1); const float limit = *args; using namespace transformer_engine; GptOssParam param = {limit}; @@ -48,7 +48,7 @@ void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, const float* void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, const float* const args, int args_size, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_dswiglu); - NVTE_CHECK(args_size==1); + NVTE_CHECK(args_size == 1); const float limit = *args; using namespace transformer_engine; GptOssParam param = {limit}; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index f1d7f6733b..5dc60a407e 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -8,7 +8,8 @@ using FuncType = void (*)(const NVTETensor, NVTETensor, cudaStream_t); using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float*, int, cudaStream_t); using DFuncType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, const float*, int, cudaStream_t); +using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, const float*, + int, cudaStream_t); template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, @@ -33,7 +34,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Compute activation directly NVTE_SCOPED_GIL_RELEASE({ if (!args.empty()) { - act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), + at::cuda::getCurrentCUDAStream()); } else { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); } @@ -44,7 
+46,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ if (!args.empty()) { - act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), + at::cuda::getCurrentCUDAStream()); } else { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); } @@ -55,7 +58,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ if (!args.empty()) { - act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), + at::cuda::getCurrentCUDAStream()); } else { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); } @@ -80,7 +84,8 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape_te = input_cpp.shape(); - const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); + const std::vector input_shape(input_shape_te.data, + input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); @@ -90,9 +95,11 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ if (!args.empty()) { - dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), at::cuda::getCurrentCUDAStream()); + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); } }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { @@ -101,9 +108,11 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ if (!args.empty()) { - dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), + args.size(), at::cuda::getCurrentCUDAStream()); } else { - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); } }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); @@ -112,9 +121,11 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ if 
(!args.empty()) { - dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), + args.size(), at::cuda::getCurrentCUDAStream()); } else { - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); } }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); @@ -163,9 +174,10 @@ py::object gpt_oss_swiglu(const at::Tensor& input, py::handle quantizer, float l return activation_helper(input, quantizer, 2, args); } -py::object gpt_oss_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, float limit) { +py::object gpt_oss_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + float limit) { std::vector args = {limit}; return dactivation_helper(grad, input, quantizer, args); } -} // namespace transformer_engine::pytorch \ No newline at end of file +} // namespace transformer_engine::pytorch From d964b24e6f9c99171ca273e0d51d1af2fa8fb956 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 23:11:53 +0000 Subject: [PATCH 08/53] accidentally had removed some activations, minor bug in the templated function Signed-off-by: Varun Thumbe --- .../pytorch/csrc/extensions/activation.cpp | 67 ++++++++++++++++--- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 5dc60a407e..a992aee562 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -5,10 +5,10 @@ namespace transformer_engine::pytorch { using FuncType = void (*)(const NVTETensor, NVTETensor, cudaStream_t); -using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float*, int, cudaStream_t); +using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float* const, int, cudaStream_t); using DFuncType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, const float*, +using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, const float* const, int, cudaStream_t); template @@ -33,10 +33,11 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE({ - if (!args.empty()) { + if(act_func == nullptr){ act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); - } else { + } + else { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); } }); @@ -45,7 +46,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (!args.empty()) { + if(act_func == nullptr){ act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -57,7 +58,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Compute activation in 
high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (!args.empty()) { + if(act_func == nullptr){ act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -94,7 +95,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - if (!args.empty()) { + if(dact_func == nullptr){ dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -107,7 +108,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (!args.empty()) { + if(dact_func == nullptr){ dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -120,7 +121,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (!args.empty()) { + if(dact_func == nullptr){ dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -151,6 +152,54 @@ py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu return dactivation_helper(grad, input, quantizer); } +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +/* ReLU and variants*/ +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object reglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object sreglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} /* Silu and variants */ py::object silu(const at::Tensor& input, 
py::handle quantizer) { return activation_helper(input, quantizer); From de9ef2fe450daae0d4ea1b647a37219f72814f66 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 23:12:35 +0000 Subject: [PATCH 09/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe --- .../pytorch/csrc/extensions/activation.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index a992aee562..4ead2878bb 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -5,11 +5,12 @@ namespace transformer_engine::pytorch { using FuncType = void (*)(const NVTETensor, NVTETensor, cudaStream_t); -using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float* const, int, cudaStream_t); +using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float* const, int, + cudaStream_t); using DFuncType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, const float* const, - int, cudaStream_t); +using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, + const float* const, int, cudaStream_t); template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, @@ -33,11 +34,10 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE({ - if(act_func == nullptr){ + if (act_func == nullptr) { act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); - } - else { + } else { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); } }); @@ -46,7 +46,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if(act_func == nullptr){ + if (act_func == nullptr) { act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -58,7 +58,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Compute activation in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if(act_func == nullptr){ + if (act_func == nullptr) { act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -95,7 +95,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - if(dact_func == nullptr){ + if (dact_func == nullptr) { dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -108,7 +108,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i auto quantizer_cpp_cs = 
dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if(dact_func == nullptr){ + if (dact_func == nullptr) { dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -121,7 +121,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if(dact_func == nullptr){ + if (dact_func == nullptr) { dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { From 8e174738cc219e40e7f987f2047c45e1366d497c Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 23:18:56 +0000 Subject: [PATCH 10/53] parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262513 +0000 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262476 +0000 parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262304 +0000 merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
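A minimal sketch (not the Transformer Engine implementation; the function and argument names here are hypothetical) of the current-scaling FP8 quantization step described in this change, assuming the amax has already been produced by an all-reduce so that the quantized data can be all-gathered in FP8:

    import jax.numpy as jnp

    def quantize_with_given_amax(x, amax, q_dtype=jnp.float8_e4m3fn, margin=0):
        # amax is assumed to be the globally reduced max(|x|), e.g. the result of an all-reduce.
        fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=jnp.float32)
        scale = (fp8_max / amax) / (2.0 ** margin)   # current scaling factor
        x_fp8 = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
        return x_fp8, 1.0 / scale                    # quantized data plus scale_inv for dequantization

Quantizing against the pre-reduced amax keeps every rank's scale identical, so the subsequent all-gather can operate on the FP8 payload rather than the higher-precision input.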
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: 
Kirthi Shankar Sivamani [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, 
see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] fix cross entropy vanishing gradients (#2139) * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy 
issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen Update list of authorized CI users (#2152) Signed-off-by: Tim Moon Fused RoPE with combined QKV input. (#2122) * Fused RoPE with combined QKV input. Initial commit for Dropout with 8-bit RNG Fix documentation Initial commit for Fused QKV RoPE WIP Initial tests passing Enable rotary percent and margin Enable CP2, start_positions, interleaved Cleanup test Revert "Fix documentation" This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715. Revert "Initial commit for Dropout with 8-bit RNG" This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca. Cleanup. Minor cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernels Signed-off-by: Vasudevan Rengasamy * Misc. 
Cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernel performance Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Move fused_qkv_rope test to test_fused_rope.py Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * apply shared memory optimization to separate fused rope kernels Signed-off-by: Xin Yao * fix lint Signed-off-by: Xin Yao --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fused_rope.py | 147 +++++- .../common/fused_rope/fused_rope.cu | 457 ++++++++++++++++-- .../include/transformer_engine/fused_rope.h | 63 +++ .../common/util/cast_gated_kernels.cuh | 12 - transformer_engine/pytorch/attention/rope.py | 163 ++++++- transformer_engine/pytorch/csrc/extensions.h | 12 + .../pytorch/csrc/extensions/activation.cpp | 12 +- .../pytorch/csrc/extensions/apply_rope.cpp | 97 ++++ .../pytorch/csrc/extensions/pybind.cpp | 4 + 9 files changed, 904 insertions(+), 63 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index ae25af9499..62d80b5529 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,25 +1,32 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
-from typing import Callable, Tuple, Union +from typing import Callable, Tuple, Union, List import math import torch import pytest from transformer_engine.pytorch.attention.rope import ( RotaryPositionEmbedding, apply_rotary_pos_emb, + apply_fused_qkv_rotary_pos_emb, ) # Gradient is a broadcasted scalar -def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: - return output.sum() * 2 +def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(t.sum() * 2 for t in output) + else: + return output.sum() * 2 # Gradient is a full tensor -def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: - t = torch.ones_like(output) - return torch.sum(output * t) +def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(torch.sum(t * torch.ones_like(t)) for t in output) + else: + t = torch.ones_like(output) + return torch.sum(output * t) @pytest.mark.parametrize("start_positions", [True, False]) @@ -238,3 +245,131 @@ def test_fused_rope_thd( torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() + + +@pytest.mark.parametrize("start_positions", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096]) +@pytest.mark.parametrize("hidden_size", [64, 128, 256]) +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) +@pytest.mark.parametrize("margin", [0, 10]) +@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) +def test_fused_qkv_rope( + dtype: torch.dtype, + seq_length: int, + hidden_size: int, + rotary_percent: float, + margin: int, + tensor_format: str, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + if margin == 0 and start_positions == True: + # This makes sure that the `start_positions` offsets being applied + # are with the maximum length of the rope embeddings. + pytest.skip("Skipping test with margin=0 and start_positions=True") + + if start_positions == True and cp_size > 1: + # `start_positions` is only supported for `cp_size=1` and inference. 
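A note on the two loss helpers above: `_overlapping_grad` reduces the output to `sum() * 2`, so every element receives the same broadcast scalar gradient of 2.0, while `_non_overlapping_grad` multiplies by a ones tensor and sums, producing a full per-element gradient tensor; the new list branches simply accumulate the same quantity over the Q/K/V outputs returned by the fused path. A minimal standalone sketch of the two gradient patterns (shapes here are illustrative and unrelated to the test):

    import torch

    x = torch.randn(2, 3, requires_grad=True)

    # "Overlapping" loss: d(sum(x) * 2)/dx is the broadcast scalar 2.
    (x.sum() * 2).backward()
    assert torch.equal(x.grad, torch.full_like(x, 2.0))

    x.grad = None

    # "Non-overlapping" loss: sum(x * ones) yields a full per-element gradient tensor.
    torch.sum(x * torch.ones_like(x)).backward()
    assert torch.equal(x.grad, torch.ones_like(x))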
+ pytest.skip("Skipping test with cp_size>1 and start_positions=True") + + if seq_length - margin < 0: + pytest.skip("Skipping test with seq_length - margin < 0") + + device = torch.device("cuda:0") + batch_size, head_num = 2, 64 + + t = torch.rand( + (seq_length - margin, batch_size, head_num, hidden_size * 6), + dtype=dtype, + device=device, + ) + + # Get arbitrary offsets to be used with RoPE for all the sequences + start_positions = ( + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) + if start_positions + else None + ) + + if tensor_format == "bshd": + t = t.transpose(0, 1).contiguous() + t.requires_grad = True + + rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_q = rotary_pos_emb_q(seq_length * cp_size) + rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_k = rotary_pos_emb_k(seq_length * cp_size) + + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + + t_clone = t.clone() + (query, key, value) = torch.split( + t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3 + ) + query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size) + + query_unfused = apply_rotary_pos_emb( + query, + emb_q, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + key_unfused = apply_rotary_pos_emb( + key, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + value_unfused = value + loss_unfused = loss_func([query_unfused, key_unfused, value_unfused]) + + if not isinstance(start_positions, torch.Tensor): + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + + t.grad = None + + # fused + query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb( + t, + emb_q, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + cp_size=cp_size, + cp_rank=cp_rank, + qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size], + ) + loss_fused = loss_func([query_fused, key_fused, value_fused]) + + if not isinstance(start_positions, torch.Tensor): + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(query_fused, query_unfused) + torch.testing.assert_close(key_fused, key_unfused) + torch.testing.assert_close(value_fused, value_unfused) + + if not isinstance(start_positions, torch.Tensor): + torch.testing.assert_close(grad_fused, grad_unfused) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index df9ea6ee5f..ccd0bc44c5 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + 
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos, v_sin; - sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos); + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos = shared_mem_cos[d_id]; + float v_sin = shared_mem_sin[d_id]; int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; @@ -49,12 +58,12 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin; - if (!interleaved) { - v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); - } else { - v_sin = - (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]); - } + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate; + float v_cos = shared_mem_cos[d_id]; + float v_src_rotate, v_sin; if (!interleaved) { - v_src_rotate = (d_id + d2 / 2 < d2) - ? static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(src[offset_src + (d2 / 2) * stride_d]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } } else { - v_src_rotate = (d_id % 2 == 0) - // d_id + 1 - ? 
static_cast(src[offset_src + stride_d]) - // d_id - 1 - : static_cast(src[offset_src - stride_d]); + if (d_id % 2 == 0) { + v_src_rotate = static_cast(src[offset_src + stride_d]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(src[offset_src - stride_d]); + v_sin = -shared_mem_sin[d_id - 1]; + } } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } - // handle the tail + // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel( offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } +template +__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out, + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + // Split the shared memory into cos and sin parts for q or k + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); + +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + v_sin = shared_mem_sin[d_id]; + float v_src = src[offset_src]; + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2)]) + : static_cast(src[offset_src + (d2 / 2 - d2)]); + } else { + v_src_rotate = (d_id % 2 == 0) ? 
-static_cast(src[offset_src + 1]) + : static_cast(src[offset_src - 1]); + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = src[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = src[offset_src]; + } + } + } + } +} + +template +__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs, + scalar_t *out, const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + // Split the shared memory into cos and sin parts for q or k + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + + float v_src = grad_out[offset_src]; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + float v_src_rotate; + if (!interleaved) { + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2)]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2 - d2)]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } + } else { + if (d_id % 2 == 0) { + v_src_rotate = static_cast(grad_out[offset_src + 1]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(grad_out[offset_src - 1]); + v_sin = -shared_mem_sin[d_id - 1]; + } + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = grad_out[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = grad_out[offset_src]; + } + } + } + } +} + +template +__global__ void fused_qkv_rope_forward_kernel( + const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs, + const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format 
qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int total_d = q_split_arg + k_split_arg + v_split_arg; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + s_id_for_freqs = s_id + begin_offset; + } + fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg); + fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg); + fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg); +} + +template +__global__ void fused_qkv_rope_backward_kernel( + const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v, + const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + int total_d = q_split_arg + k_split_arg + v_split_arg; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * 
cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_q, h, d, d2, 0, total_d, + q_split_arg); + fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d, + k_split_arg); + fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d, + v_split_arg); +} + template void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs, const int *start_positions, scalar_t *output, @@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_forward_kernel<<>>( + fused_rope_forward_kernel<<>>( input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); @@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_backward_kernel<<>>( + fused_rope_backward_kernel<<>>( output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs, + const float *k_freqs, const int *start_positions, + scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + int warps_per_block = (h <= 8) ? 
h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_forward_kernel<<>>( + qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out, + const scalar_t *v_grad_out, const float *q_freqs, + const float *k_freqs, scalar_t *qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, + const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + const int warps_per_block = (h <= 8) ? h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_backward_kernel<<>>( + q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, const Tensor &start_positions, Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, @@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c stride_b, stride_h, stride_d, stream);); } +void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs, + const Tensor &start_positions, Tensor *q_out, Tensor *k_out, + Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + qkv_input.data.dtype, scalar_t, + fused_qkv_rope_forward_launcher(reinterpret_cast(qkv_input.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), + reinterpret_cast(q_out->data.dptr), + reinterpret_cast(k_out->data.dptr), + reinterpret_cast(v_out->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} + +void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out, + const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs, + Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + q_grad_out.data.dtype, scalar_t, + fused_qkv_rope_backward_launcher(reinterpret_cast(q_grad_out.data.dptr), + reinterpret_cast(k_grad_out.data.dptr), + reinterpret_cast(v_grad_out.data.dptr), + reinterpret_cast(q_freqs.data.dptr), 
+ reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(qkv_grad_input->data.dptr), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} } // end namespace transformer_engine void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, @@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } + +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_forward); + using namespace transformer_engine; + fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out), + convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank, + s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream); +} + +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_backward); + using namespace transformer_engine; + fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out), + *convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index f0817a97fe..610868f932 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream); +/*! \brief Apply rotary positional embedding to the combined QKV input tensor. + * + * \param[in] qkv_input Combined QKV input tensor for fused rope. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. + * \param[out] q_out Output tensor for Q. + * \param[out] k_out Output tensor for K. + * \param[out] v_out Output tensor for V. + * \param[in] qkv_format QKV format. 
+ * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + +/*! \brief Compute the backward of the fused qkv rope. + * + * \param[in] q_grad_out Incoming gradient tensor for Q. + * \param[in] k_grad_out Incoming gradient tensor for K. + * \param[in] v_grad_out Incoming gradient tensor for V. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[out] qkv_grad_input Input gradient tensor to calculate. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. 
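To make the dimension parameters above concrete, a small shape-bookkeeping sketch in Python (the numbers mirror the unit test earlier in this patch; the 4:1:1 Q/K/V split is the test's configuration, not a requirement of the API):

    s, b, h, d = 2048, 2, 64, 128        # sequence, batch, heads, head dim of the packed input
    rotary_percent = 0.5
    d2 = int(d * rotary_percent)         # rotary dim of the freqs tensors (d2 <= d)

    qkv_split_arg_list = [4 * d, d, d]   # hidden sizes for Q, K, V within the packed last dim
    total_d = sum(qkv_split_arg_list)    # last dim of qkv_input: 6 * d

    qkv_input_shape = (s, b, h, total_d)     # "sbhd" layout
    q_out_shape = (s, b, 4 * h, d)           # Q is unpacked to h * (4d / d) heads of size d
    k_out_shape = (s, b, h, d)
    v_out_shape = (s, b, h, d)               # V is copied through without rotation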
+ */ +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index ce74761a22..13e028fba7 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1150,11 +1150,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); @@ -1171,11 +1167,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out break; case ScalingType::BIDIMENSIONAL: NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - mxfp8_kernel::cast_mxfp8_gated_kernel <<>>( @@ -1192,11 +1184,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out template void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); NVTE_CHECK(input.flat_last_dim() % 2 == 0, "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 60685a31d9..139381f2dd 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -5,14 +5,14 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -170,6 +170,86 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], return grad_input, None, None, None, None, None, None, None +class FusedQKVRoPEFunc(torch.autograd.Function): + """ + Function for FusedQKVRoPE + + This implementation accepts combined QKV tensor in `bshd` or `sbhd` format. Q and K RoPE tensors are the additional required inputs. 
+ The RoPE tensors should be of shape (s, 1, 1, d). It produces 3 outputs: Q, K after RoPE, V is the same as input. + """ + + @staticmethod + def forward( + ctx, + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + start_positions: Union[torch.Tensor, None] = None, + tensor_format: str = "sbhd", + interleaved: bool = False, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """Fused RoPE forward.""" + + if q_freqs.dtype != torch.float32: + q_freqs = q_freqs.float() + if k_freqs.dtype != torch.float32: + k_freqs = k_freqs.float() + assert tensor_format in ( + "sbhd", + "bshd", + ), f"Unsupported tensor_format: {tensor_format}." + assert qkv.is_contiguous(), "QKV Tensor should be contiguous." + assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous." + assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous." + output = tex.fused_qkv_rope_forward( + qkv, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + QKVFormat[tensor_format], + interleaved, + cp_size, + cp_rank, + ) + ctx.save_for_backward(q_freqs, k_freqs) + ctx.tensor_format = tensor_format + ctx.qkv_split_arg_list = qkv_split_arg_list + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank + ctx.interleaved = interleaved + return output + + @staticmethod + def backward( + ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused RoPE backward.""" + q_freqs, k_freqs = ctx.saved_tensors + + grad_output_q = grad_output_q.contiguous() + grad_output_k = grad_output_k.contiguous() + grad_output_v = grad_output_v.contiguous() + + grad_input = tex.fused_qkv_rope_backward( + grad_output_q, + grad_output_k, + grad_output_v, + q_freqs, + k_freqs, + ctx.qkv_split_arg_list, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + ctx.cp_size, + ctx.cp_rank, + ) + + return grad_input, None, None, None, None, None, None, None, None + + def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: """Change sign so the last dimension becomes [-odd, +even] @@ -393,3 +473,82 @@ def apply_rotary_pos_emb( tensor_format, interleaved=interleaved, ) + + +def apply_fused_qkv_rotary_pos_emb( + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, + interleaved: bool = False, + cu_seqlens: Union[torch.Tensor, None] = None, # pylint: disable=unused-argument + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input qkv tensor. + + Support matrix: + Fused: + Training: + qkv_formats: "bshd", "sbhd" + context parallel: yes + start_positions: no + interleaving: yes + Inference: + qkv_formats: "bshd", "sbhd" + context parallelism: no + start_positions: yes + interleaving: yes + + Parameters + ---------- + qkv: torch.Tensor + Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which + rotary positional embedding will be applied. This tensor has q, k, v concatenated + along the last dimension. + q_freqs: torch.Tensor + Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + k_freqs: torch.Tensor + Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + qkv_split_arg_list: List[int] + List of integers that specify the split of the qkv tensor. 
The list should have 3 elements, + the first element is the number of elements in the q tensor, the second element is the number + of elements in the k tensor, and the third element is the number of elements in the v tensor. + The sum of the elements in the list should be equal to the last dimension of the qkv tensor. + start_positions: torch.Tensor, default = None. + Tokens in a sequence `i` should be applied with position encoding offset by + `start_positions[i]`. If `start_positions=None`, there's no offset. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is + of shape `[seq, bs, ...]`. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + cp_size: int, default = 1. + Context parallel world size. + cp_rank: int, default = 0. + Context parallel rank. + """ + + # `start_positions` is only supported for `cp_size=1` and inference. + assert not ( + cp_size > 1 and start_positions is not None + ), """start_positions != None with CP SIZE > 1 is not supported!""" + + assert tensor_format != "thd", "'thd' tensor_format not supported currently." + + return FusedQKVRoPEFunc.apply( + qkv, + q_freqs, + k_freqs, + qkv_split_arg_list, + start_positions, + tensor_format, + interleaved, + cp_size, + cp_rank, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8495179831..9b933256cc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -342,6 +342,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor const std::optional cu_seqlens, const int cp_size, const int cp_rank); +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); + +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank); + /*************************************************************************************************** * Miscellaneous **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 4ead2878bb..a9f6d5a404 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -34,7 +34,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE({ - if (act_func == nullptr) { + if constexpr (act_func == nullptr) { act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -46,7 +46,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); 
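For reference, a condensed usage sketch of the Python entry point documented above, mirroring the unit test added earlier in this patch (the shapes and the [4*d, d, d] split are the test's configuration, not requirements):

    import torch
    from transformer_engine.pytorch.attention.rope import (
        RotaryPositionEmbedding,
        apply_fused_qkv_rotary_pos_emb,
    )

    s, b, h, d = 2048, 2, 64, 128
    qkv = torch.randn(s, b, h, 6 * d, dtype=torch.bfloat16, device="cuda", requires_grad=True)

    rope = RotaryPositionEmbedding(d)
    q_freqs = rope(s)   # (s, 1, 1, d), float32
    k_freqs = rope(s)

    q, k, v = apply_fused_qkv_rotary_pos_emb(
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list=[4 * d, d, d],
        tensor_format="sbhd",
    )
    # q: (s, b, 4*h, d) and k: (s, b, h, d) have RoPE applied; v: (s, b, h, d) is passed through.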
NVTE_SCOPED_GIL_RELEASE({ - if (act_func == nullptr) { + if constexpr (act_func == nullptr) { act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -58,7 +58,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Compute activation in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (act_func == nullptr) { + if constexpr (act_func == nullptr) { act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -95,7 +95,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - if (dact_func == nullptr) { + if constexpr (dact_func == nullptr) { dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -108,7 +108,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (dact_func == nullptr) { + if constexpr (dact_func == nullptr) { dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { @@ -121,7 +121,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if (dact_func == nullptr) { + if constexpr (dact_func == nullptr) { dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), args.size(), at::cuda::getCurrentCUDAStream()); } else { diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 6f6f827252..d1ba1a351c 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, return output; } +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, + const int cp_rank) { + TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + // output + auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device()); + auto q_out_size = 
qkv_input.sizes().vec(); + q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1]; + q_out_size[3] = qkv_split_arg_list[1]; + auto q_out = at::empty(q_out_size, act_options); + auto k_out_size = qkv_input.sizes().vec(); + k_out_size[3] = qkv_split_arg_list[1]; + auto k_out = at::empty(k_out_size, act_options); + auto v_out_size = qkv_input.sizes().vec(); + v_out_size[3] = qkv_split_arg_list[2]; + auto v_out = at::empty(v_out_size, act_options); + + auto qkv_cu = makeTransformerEngineTensor(qkv_input); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto q_out_cu = makeTransformerEngineTensor(q_out); + auto k_out_cu = makeTransformerEngineTensor(k_out); + auto v_out_cu = makeTransformerEngineTensor(v_out); + + auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + } + + TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor"); + TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous"); + + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1); + const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0); + const int h = qkv_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(), + start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(), + v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h, + d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1], + qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(q_out, k_out, v_out); +} + at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, @@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor return input_grads; } +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank) { + auto act_options = + at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device()); + auto qkv_grad_size = q_grad_out.sizes().vec(); + auto total_hd = + (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3); + auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2]; + qkv_grad_size[2] = total_hd / total_d; + qkv_grad_size[3] = total_d; + auto qkv_grad_input = at::empty(qkv_grad_size, act_options); + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1); + const int b = is_sbhd ? 
q_grad_out.size(1) : q_grad_out.size(0); + const int h = qkv_grad_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out); + auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out); + auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input); + + nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(), + q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0], + qkv_split_arg_list[1], qkv_split_arg_list[2], + at::cuda::getCurrentCUDAStream()); + + return qkv_grad_input; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index b81608b1b2..51ea1a0e8d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -284,6 +284,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Apply RoPE FWD", py::call_guard()); m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); + m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward, + "Fused Apply QKV RoPE FWD", py::call_guard()); + m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, + "Fused Apply QKV RoPE BWD", py::call_guard()); // fused router m.def("fused_topk_with_score_function_fwd", From 1f2c65bda2a75e974ba337196f22bf9b16006e95 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 23:26:33 +0000 Subject: [PATCH 11/53] accidentally removed the copyright Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/extensions/activation.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index a9f6d5a404..83979ac291 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,3 +1,8 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ #include "../extensions.h" #include "common.h" #include "pybind.h" From 75c4b13f0530a14447bbff38ea2f0c5bb8885b24 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 23:39:34 +0000 Subject: [PATCH 12/53] fix linting issue Signed-off-by: Varun Thumbe --- .../common/util/cast_gated_kernels.cuh | 12 ++---------- transformer_engine/common/util/math.h | 6 +++--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 13e028fba7..c23bba2e78 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -511,11 +511,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if (x < limit) { - dact_x = s + s * (1 - s) * 1.702 * x; - } else { - dact_x = 0.0f; - } + dact_x = x < limit ? s + s * (1 - s) * 1.702 * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); @@ -772,11 +768,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, limit); const float s = sigmoidf(1.702 * x); act_x = x * s; - if (x < limit) { - dact_x = s + s * (1 - s) * 1.702 * x; - } else { - dact_x = 0.0f; - } + dact_x = x < limit ? s + s * (1 - s) * 1.702 * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index e3843b0a28..e347c4afb5 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -64,7 +64,7 @@ __device__ inline OType silu(const IType val, const Empty& e) { template __device__ inline OType oss_silu(const IType val, const GptOssParam& p) { const Empty e = {}; - const float cval = min(p.limit, (float)val); // Clamping + const float cval = min(p.limit, static_cast(val)); // Clamping return qgelu(cval, e); } @@ -77,8 +77,8 @@ __device__ inline OType dsilu(const IType val, const Empty& e) { template __device__ inline OType oss_dsilu(const IType val, const GptOssParam& p) { const Empty e = {}; - const bool dclamp_val = (float)val <= p.limit; - const float clamp_val = min((float)val, p.limit); + const bool dclamp_val = static_cast(val) <= p.limit; + const float clamp_val = min(static_cast(val), p.limit); const float dsilu_val = dqgelu(clamp_val, e); return dclamp_val ? dsilu_val : 0.0f; } From 288e9266c1efbe3ae4a1ec3658e5e2cc24ca7b62 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 8 Sep 2025 23:57:37 +0000 Subject: [PATCH 13/53] minor issue in comments Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/ops/basic/activation.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 7dfd655e18..5b71d9b032 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -402,12 +402,6 @@ class GptOssSwiglu(_ActivationOperation): \text{GPT-OSS-SwiGLU}(a, b) = \text{clamp}(a, -\infty, \text{limit}) \cdot \sigma(1.702 \cdot \text{clamp}(a, -\infty, \text{limit})) \cdot (\text{clamp}(b, -\text{limit}, \text{limit}) + 1) - where - - .. 
math:: - - a = x[..., ::2], \quad b = x[..., 1::2] - and :math:`\sigma(x)` is the sigmoid function, and :math:`\text{limit}` is a hyperparameter. Implementation based on `GPT-OSS`__. From 448eceba8398bbea980276516d1338a10cc14d5d Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 9 Sep 2025 20:57:28 -0700 Subject: [PATCH 14/53] Commit is for another PR Signed-off-by: vthumbe1503 --- tests/pytorch/test_sanity.py | 50 +----------------------------------- 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index ae364f80a9..5151aa96e7 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -38,7 +38,7 @@ Float8Quantizer, Float8Tensor, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from utils import ModelConfig @@ -911,54 +911,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): torch.cuda.synchronize() -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("N", [32]) -@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize( - "input_quantizer", - [ - Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), - MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), - ], -) -@pytest.mark.parametrize( - "out_quantizer", - [ - Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), - MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), - ], -) -def test_sanity_fp8gemm_with_quantization(N, datatype, input_quantizer, out_quantizer): - # For MXFP8 and CurrentScaling, below unfused quantization should happen - # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output - offset = 32 - scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype) - scratchpad_fp8 = input_quantizer(scratchpad) - inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N)) - weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N)) - outp_type = torch.float32 - quantized_out, *_ = general_gemm( - weight_fp8, - inp_fp8, - get_workspace(), - outp_type, - quantization_params=out_quantizer, - bias=None, - use_split_accumulator=False, - ) - out, *_ = general_gemm( - weight_fp8, - inp_fp8, - get_workspace(), - outp_type, - quantization_params=None, - bias=None, - use_split_accumulator=False, - ) - expected_quantized_out = out_quantizer(out) - torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize()) - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): """Test the functionality of replace_raw_data""" From 23b582232d19a1d12b53a612fc496da64b9fda57 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 9 Sep 2025 21:01:51 -0700 Subject: [PATCH 15/53] revert changes since this belongs to another PR Signed-off-by: vthumbe1503 --- transformer_engine/pytorch/csrc/quantizer.cpp | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index cd7e70fecb..8f0d8c7dc3 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -96,6 +96,16 @@ void 
Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); } std::pair Float8Quantizer::create_tensor( @@ -308,6 +318,17 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); + // quantize output and its transpose + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); } std::pair Float8CurrentScalingQuantizer::create_tensor( @@ -541,7 +562,20 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // Change the rowwise and columnwise_data to the configured dtype. + // May be a switch between E5M2 and E4M3. 
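+  // The same dtype propagation as in the other quantizers in this file: both
+  // the rowwise and columnwise views of the tensor are re-stamped with the
+  // quantizer's configured FP8 dtype (a possible E4M3/E5M2 switch) so that
+  // downstream kernels interpret the raw data pointers consistently.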
+ auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -883,7 +917,18 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize this->dtype = quantizer.attr("dtype").cast(); } -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { From a1a5794a8c803dcf5a1f8bfa2dd11da15c2aef71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 04:02:19 +0000 Subject: [PATCH 16/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8f0d8c7dc3..c690cd522a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -574,7 +574,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), rowwise_data.shape); tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); + columnwise_data.shape); } std::pair Float8BlockQuantizer::create_tensor( @@ -927,7 +927,7 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), rowwise_data.shape); tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); + columnwise_data.shape); } std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, From 0d6a3ea6fa826f1f583ed5294625b2ff9b56ebc9 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 9 Sep 2025 21:08:03 -0700 Subject: [PATCH 17/53] Revert change back since belongs to another PR Signed-off-by: vthumbe1503 --- .../pytorch/csrc/extensions/gemm.cpp | 66 ++++++------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b9f91c7195..3117705de3 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -93,8 +93,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans bool use_split_accumulator, CommOverlapCore* 
comm_overlap, std::optional comm_type, MaybeTensor extra_output, bool bulk_overlap, float alpha, std::optional beta) { - using namespace transformer_engine::pytorch::detail; - // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -125,10 +123,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans "into D tensor. Beta has nothing to be applied to."); } - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); // Output tensor TensorWrapper D_tensor; if (D.is_none()) { + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); @@ -141,33 +139,12 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } - // maintain unquantized tensor in case we need unfused quantization support. - TensorWrapper unquantized_D_tensor; - py::object unquantized_out; - // Unfused quantization is needed in the following cases - // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) - // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, - // GEMM Output needs to be in BF16, to allow for unfused quantization) - bool unfused_quantization_needed; - if (low_precision) { - unfused_quantization_needed = !quantizer.is_none() && !IsFloat8Quantizers(quantizer.ptr()); - } else { - unfused_quantization_needed = !quantizer.is_none(); - } - - if (unfused_quantization_needed) { - NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); - } - TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; - // Bias tensor TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = - torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); + auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -180,7 +157,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : out_tensor.dtype(); + DType gelu_type = low_precision ? 
bias_type : D_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -233,7 +210,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -241,14 +218,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -257,14 +234,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -274,15 +251,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (out_tensor.numel() != 0 && !accumulate) { - out_tensor.zero_(main_stream); + if (D_tensor.numel() != 0 && !accumulate) { + D_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -290,8 +267,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - if (unfused_quantization_needed) my_quantizer->quantize(unquantized_D_tensor, D_tensor); + // Pack outputs std::vector out; out.emplace_back(std::move(D)); @@ -385,7 +361,7 @@ std::optional> te_general_grouped_gemm( auto te_B = makeTransformerEngineTensor(B[i], none); // if there is single output - at::Tensor out_tensor; + at::Tensor D_tensor; auto size_t_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; @@ -400,31 +376,31 @@ std::optional> te_general_grouped_gemm( 
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); if (single_output) { if (output_data_ptr == nullptr) { - out_tensor = at::empty(D_shape, opts); + D_tensor = at::empty(D_shape, opts); } else { // We need to check !D_numel_is_zero because if the final input portion has zero elements, // output_data_ptr would point beyond the allocated memory of D. This would cause // at::from_blob to fail as it would reference memory not allocated by CUDA. if (!D_numel_is_zero) { - out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + D_tensor = at::from_blob(output_data_ptr, D_shape, opts); } } char* char_ptr = reinterpret_cast(output_data_ptr); char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size(); output_data_ptr = reinterpret_cast(char_ptr); - D_vectors.emplace_back(out_tensor); + D_vectors.emplace_back(D_tensor); } else { if (D == std::nullopt) { auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); - out_tensor = at::empty(D_shape, opts); - D_vectors.emplace_back(out_tensor); + D_tensor = at::empty(D_shape, opts); + D_vectors.emplace_back(D_tensor); } else { - out_tensor = (*D)[i]; + D_tensor = (*D)[i]; } } if (te_A.numel() == 0 || te_B.numel() == 0) { - if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); + if (D_tensor.numel() != 0 && !accumulate) D_tensor.zero_(); if (bias[i].numel() != 0 && grad) { bias[i].zero_(); } @@ -432,7 +408,7 @@ std::optional> te_general_grouped_gemm( continue; } - auto te_D = makeTransformerEngineTensor(out_tensor); + auto te_D = makeTransformerEngineTensor(D_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); From 33c3364d990037adcdab75aea362fe64e3be8c21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 04:08:31 +0000 Subject: [PATCH 18/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 3117705de3..df1c64a344 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -267,7 +267,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - + // Pack outputs std::vector out; out.emplace_back(std::move(D)); From a724c2d0bfcdff7873d469c2d787052357bd1232 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 9 Sep 2025 21:11:48 -0700 Subject: [PATCH 19/53] Changes belong to another PR Signed-off-by: vthumbe1503 --- .../pytorch/csrc/extensions/gemm.cpp | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index df1c64a344..7531b91d80 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -124,18 +124,18 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } // Output tensor - TensorWrapper D_tensor; + TensorWrapper out_tensor; if (D.is_none()) { DType output_dtype = out_dtype ? 
*out_dtype : A_tensor.dtype(); - std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + std::tie(out_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { - D_tensor = makeTransformerEngineTensor(D, quantizer); - NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), + out_tensor = makeTransformerEngineTensor(D, quantizer); + NVTE_CHECK(detail::checkGemmShape(D_shape, out_tensor.shape()), "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", - std::to_string(D_tensor.shape()), ")"); + std::to_string(out_tensor.shape()), ")"); if (out_dtype) { - NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", - static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); + NVTE_CHECK(*out_dtype == out_tensor.dtype(), "GEMM output has invalid dtype (expected ", + static_cast(*out_dtype), ", found ", static_cast(out_tensor.dtype()), ")"); } } @@ -144,7 +144,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + auto opts = torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -157,7 +157,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : D_tensor.dtype(); + DType gelu_type = low_precision ? bias_type : out_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -210,7 +210,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -218,14 +218,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -234,14 +234,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, 
extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -251,15 +251,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(main_stream); + if (out_tensor.numel() != 0 && !accumulate) { + out_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -294,8 +294,8 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter) { // TODO: Handle scaling modes - NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; - NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYEout_tensor_SCALING; + NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYEout_tensor_SCALING; auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, @@ -361,7 +361,7 @@ std::optional> te_general_grouped_gemm( auto te_B = makeTransformerEngineTensor(B[i], none); // if there is single output - at::Tensor D_tensor; + at::Tensor out_tensor; auto size_t_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; @@ -376,31 +376,31 @@ std::optional> te_general_grouped_gemm( auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); if (single_output) { if (output_data_ptr == nullptr) { - D_tensor = at::empty(D_shape, opts); + out_tensor = at::empty(D_shape, opts); } else { // We need to check !D_numel_is_zero because if the final input portion has zero elements, // output_data_ptr would point beyond the allocated memory of D. This would cause // at::from_blob to fail as it would reference memory not allocated by CUDA. 
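          // Illustrative layout for this single-output path, assuming one
          // preallocated buffer D holding all GEMM outputs back to back:
          //   view_i starts at base + (m_0*n_0 + ... + m_{i-1}*n_{i-1}) * element_size
          // and output_data_ptr is advanced by D_shape[0] * D_shape[1] * element_size
          // after each iteration (see the pointer arithmetic below).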
if (!D_numel_is_zero) { - D_tensor = at::from_blob(output_data_ptr, D_shape, opts); + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); } } char* char_ptr = reinterpret_cast(output_data_ptr); char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size(); output_data_ptr = reinterpret_cast(char_ptr); - D_vectors.emplace_back(D_tensor); + D_vectors.emplace_back(out_tensor); } else { if (D == std::nullopt) { auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); - D_tensor = at::empty(D_shape, opts); - D_vectors.emplace_back(D_tensor); + out_tensor = at::empty(D_shape, opts); + D_vectors.emplace_back(out_tensor); } else { - D_tensor = (*D)[i]; + out_tensor = (*D)[i]; } } if (te_A.numel() == 0 || te_B.numel() == 0) { - if (D_tensor.numel() != 0 && !accumulate) D_tensor.zero_(); + if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); if (bias[i].numel() != 0 && grad) { bias[i].zero_(); } @@ -408,7 +408,7 @@ std::optional> te_general_grouped_gemm( continue; } - auto te_D = makeTransformerEngineTensor(D_tensor); + auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); From 34d98158fece74cc43f7b3ab15f8d6d723a9f8fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 04:12:16 +0000 Subject: [PATCH 20/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions/gemm.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 7531b91d80..f5cd80827c 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -135,7 +135,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans std::to_string(out_tensor.shape()), ")"); if (out_dtype) { NVTE_CHECK(*out_dtype == out_tensor.dtype(), "GEMM output has invalid dtype (expected ", - static_cast(*out_dtype), ", found ", static_cast(out_tensor.dtype()), ")"); + static_cast(*out_dtype), ", found ", static_cast(out_tensor.dtype()), + ")"); } } @@ -144,7 +145,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); + auto opts = + torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { From 347526496c8b6439fecebd26fae4b8ab773fe02a Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 9 Sep 2025 21:15:55 -0700 Subject: [PATCH 21/53] Revert changes here Signed-off-by: vthumbe1503 Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162) * add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion Signed-off-by: tongliu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tongliu Co-authored-by: tongliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [JAX] Scale swizzling via JAX transpose op (#2163) * add swizzle in jax Signed-off-by: Phuong Nguyen * 
added outer_impl Signed-off-by: Phuong Nguyen * clean up FFI Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Extract cpp distributed tests into a separate project (#2165) * Extract cpp distributed tests into a separate project Signed-off-by: Vladimir Cherepanov * Remove obsolete exclusion Signed-off-by: Vladimir Cherepanov * Run L1_cpp_distributed tests if at least 4 GPUs Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (#2129) * test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell * assert line change Signed-off-by: Jonathan Mitchell * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jonathan Mitchell Co-authored-by: Jonathan Mitchell Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sudhakar Singh --- qa/L0_cppunittest/test.sh | 2 +- qa/L1_cpp_distributed/test.sh | 10 +- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/cpp/CMakeLists.txt | 1 - tests/cpp/comm_gemm/CMakeLists.txt | 19 - tests/cpp_distributed/CMakeLists.txt | 57 ++ .../test_comm_gemm.cu | 2 +- tests/pytorch/attention/test_cp_utils.py | 715 ++++++++++++++++++ .../common/fused_router/fused_moe_aux_loss.cu | 2 +- .../common/fused_router/utils.h | 8 + transformer_engine/jax/cpp_extensions/base.py | 9 +- transformer_engine/jax/cpp_extensions/gemm.py | 97 ++- .../jax/csrc/extensions/gemm.cpp | 52 +- .../dot_product_attention/context_parallel.py | 211 +++++- .../pytorch/csrc/extensions/gemm.cpp | 40 +- 15 files changed, 1100 insertions(+), 126 deletions(-) delete mode 100644 tests/cpp/comm_gemm/CMakeLists.txt create mode 100644 tests/cpp_distributed/CMakeLists.txt rename tests/{cpp/comm_gemm => cpp_distributed}/test_comm_gemm.cu (99%) create mode 100644 tests/pytorch/attention/test_cp_utils.py diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index aa56d69ed6..cd46b0b63c 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp cmake -GNinja -Bbuild . cmake --build build export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) -ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)' +ctest --test-dir build -j$NUM_PARALLEL_JOBS diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index f4f914b3e9..e074b46ae6 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -9,7 +9,9 @@ set -e TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH -cd $TE_PATH/tests/cpp -cmake -GNinja -S. -Bbuild -cmake --build build -mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm +if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then + cd $TE_PATH/tests/cpp_distributed + cmake -GNinja -S. 
-Bbuild + cmake --build build + mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm +fi diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index e5b4b58617..7f061d222a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 412c5d34d9..c2c9d0d915 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,6 +43,5 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) -add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/tests/cpp/comm_gemm/CMakeLists.txt b/tests/cpp/comm_gemm/CMakeLists.txt deleted file mode 100644 index 55f5207acf..0000000000 --- a/tests/cpp/comm_gemm/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -add_executable(test_comm_gemm - test_comm_gemm.cu - ../test_common.cu) - -find_package(OpenMP REQUIRED) -find_package(MPI REQUIRED) -find_library(NCCL_LIB - NAMES nccl libnccl - PATH_SUFFIXES lib - REQUIRED) -target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) -target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) - -include(GoogleTest) -gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt new file mode 100644 index 0000000000..ed3ddeb885 --- /dev/null +++ b/tests/cpp_distributed/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
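+# Standalone CMake project for the multi-GPU comm_gemm tests. The resulting
+# test_comm_gemm binary is intended to be launched through mpirun (see
+# qa/L1_cpp_distributed/test.sh, which runs it with 4 ranks when at least
+# 4 GPUs are present) and depends on MPI, NCCL and the cuBLASMp headers
+# located via the CUBLASMP_HOME environment variable.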
+ +cmake_minimum_required(VERSION 3.18) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() +endif() + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +project(transformer_engine_distributed_tests LANGUAGES CUDA CXX) + +add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) + +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) + +if(NOT DEFINED TE_LIB_PATH) + execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" + OUTPUT_VARIABLE TE_LIB_FILE + OUTPUT_STRIP_TRAILING_WHITESPACE) + get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) +endif() + +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) + +message(STATUS "Found transformer_engine library: ${TE_LIB}") +include_directories(../../transformer_engine/common/include) +include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) +include_directories(${CMAKE_SOURCE_DIR}) + +find_package(CUDAToolkit REQUIRED) + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../cpp/test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/comm_gemm/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu similarity index 99% rename from tests/cpp/comm_gemm/test_comm_gemm.cu rename to tests/cpp_distributed/test_comm_gemm.cu index b34d4db4b8..8355d5f96f 100644 --- a/tests/cpp/comm_gemm/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_common.h" +#include "../cpp/test_common.h" #include "common.h" using transformer_engine::DType; diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py new file mode 100644 index 0000000000..00200c62d2 --- /dev/null +++ b/tests/pytorch/attention/test_cp_utils.py @@ -0,0 +1,715 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
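+# A minimal sketch of the contract exercised below, assuming the padding
+# utilities keep the signatures imported from context_parallel.py:
+#   ids_p, labels_p, cu_seqlens_p = pad_thd_sequences_for_cp(
+#       input_ids.unsqueeze(0), labels.unsqueeze(0), cu_seqlens,
+#       divisibility_factor, padding_token_id=0, padding_label_id=-100)
+#   pos_p = generate_positional_ids_for_cp(cu_seqlens, divisibility_factor)
+# Every sequence in the THD batch is padded up to the next multiple of
+# divisibility_factor, cu_seqlens is adjusted to the padded boundaries, and
+# positional ids run 0..padded_len-1 within each padded sequence.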
+ +"""Unit tests for context parallel utils.""" +import torch +import unittest +from typing import Tuple +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + get_batch_on_this_cp_rank, + pad_thd_sequences_for_cp, + generate_positional_ids_for_cp, +) + + +class TestSequencePadding(unittest.TestCase): + def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self): + """Test with custom padding values for all tensors.""" + # Setup + + input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3]) + cu_seqlens = torch.tensor([0, 3, 5, 9]) + labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100]) + positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]) + divisibility_factor = 8 + + pid = 777 + label_pad = -200 + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Sequence: [ a a a p p p p p b b pppppp ccccpppp] + print("input_ids_padded: ", input_ids_padded) + print("labels_padded: ", labels_padded) + print("positional_ids_padded: ", positional_ids_padded) + print("cu_seqlens_padded: ", cu_seqlens_padded) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + pid, + pid, + pid, + pid, + pid, + 2, + 2, + pid, + pid, + pid, + pid, + pid, + pid, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + pid, + ] + ) + expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24]) + expected_labels_padded = torch.tensor( + [ + -100, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + 13, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + ] + ) + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + assert torch.equal(input_ids_padded, expected_input_ids) + assert torch.equal(labels_padded, expected_labels_padded) + assert torch.equal(positional_ids_padded, expected_positional_ids) + assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded) + + def test_mixed_sequence_lengths_with_divisibility_factor(self): + """Test with sequences both shorter and longer than divisibility factor.""" + # Setup - divisibility factor 6 + # Seq 1: length 2 (shorter than 6, needs 4 padding) + # Seq 2: length 7 (longer than 6, needs 5 padding to reach 12) + # Seq 3: length 4 (shorter than 6, needs 2 padding) + # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12) + + input_ids = torch.tensor( + [1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + ) + labels = torch.tensor( + [ + 10, + 11, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 30, + 31, + 32, + 33, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + ] + ) + positional_ids = torch.tensor( + [0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + cu_seqlens = torch.tensor([0, 2, 9, 13, 23]) + divisibility_factor = 6 + + pid = 999 + label_pad = -300 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( 
+ cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: [1,1] + 4 pads = 6 total + # Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total + # Seq 3: [3,3,3,3] + 2 pads = 6 total + # Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total + + expected_input_ids = torch.tensor( + [ + 1, + 1, + pid, + pid, + pid, + pid, # Seq 1: 2 + 4 padding + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + pid, + pid, + pid, + pid, # Seq 2: 7 + 5 padding + 3, + 3, + 3, + 3, + pid, + pid, # Seq 3: 4 + 2 padding + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + pid, + pid, # Seq 4: 10 + 2 padding + ] + ) + + expected_labels = torch.tensor( + [ + 10, + 11, + label_pad, + label_pad, + label_pad, + label_pad, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + 30, + 31, + 32, + 33, + label_pad, + label_pad, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [ + 0, + 1, + 2, + 3, + 4, + 5, # Seq 1 positions continue through padding + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 2 positions continue + 0, + 1, + 2, + 3, + 4, + 5, # Seq 3 positions continue + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 4 positions continue + ] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + def test_sequences_longer_than_divisibility_factor(self): + """Test with all sequences longer than the divisibility factor.""" + # Setup - divisibility factor 4, all sequences longer than 4 + # Seq 1: length 7 (needs 1 padding to reach 8) + # Seq 2: length 11 (needs 1 padding to reach 12) + # Seq 3: length 5 (needs 3 padding to reach 8) + + input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, # 7 tokens + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, # 11 tokens + 3, + 3, + 3, + 3, + 3, # 5 tokens + ] + ) + labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 300, + 301, + 302, + 303, + 304, + ] + ) + positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4] + ) + cu_seqlens = torch.tensor([0, 7, 18, 23]) + divisibility_factor = 4 + + pid = 888 + label_pad = -400 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: 7 + 1 pad = 8 (divisible by 4) + # Seq 2: 11 + 1 pad = 12 (divisible by 4) + # Seq 3: 5 + 3 pads = 8 (divisible by 4) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + pid, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + 3, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + ] + ) + + expected_labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + label_pad, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + label_pad, + 300, + 301, + 302, + 303, + 304, + label_pad, + label_pad, + label_pad, + ] 
+ ) + + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + +class TestContextParallelUtils(unittest.TestCase): + """Test utilities for context parallel functionality.""" + + def setUp(self): + """Set up mock distributed environment.""" + # Mock torch.distributed functions + self.original_get_world_size = torch.distributed.get_world_size + self.original_get_rank = torch.distributed.get_rank + + def tearDown(self): + """Restore original torch.distributed functions.""" + torch.distributed.get_world_size = self.original_get_world_size + torch.distributed.get_rank = self.original_get_rank + + def _mock_distributed_env(self, cp_size, cp_rank): + """Mock the distributed environment for testing.""" + + def mock_get_world_size(group=None): + return cp_size + + def mock_get_rank(group=None): + return cp_rank + + torch.distributed.get_world_size = mock_get_world_size + torch.distributed.get_rank = mock_get_rank + + def test_cp_rank_slicing_simple_case(self): + """Test CP rank slicing with a simple 2-rank, single sequence case.""" + # Setup: Single sequence of length 8, CP size = 2 + # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each + # Rank 0 gets slices [0,1] and [6,7] (first and last) + # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last) + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]]) + expected_labels_r1 = torch.tensor([[30, 40, 50, 60]]) + expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5]) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_cp_rank_slicing_multiple_sequences(self): + """Test CP rank slicing with multiple sequences.""" + # Setup: Two sequences of length 8 each, CP size = 2 + # Total sequence length = 16, cu_seqlens = [0, 8, 16] + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]) + labels = torch.tensor( + [[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 
130, 140, 150, 160, 170, 180]] + ) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8, 16]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # For each sequence, rank 0 gets first and last slices + # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8] + # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_with_cp_size_1(self): + """Test that CP size = 1 returns original tensors unchanged.""" + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=1, cp_rank=0) + input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # With CP size = 1, should return original tensors + self.assertTrue(torch.equal(input_ids_result, input_ids)) + self.assertTrue(torch.equal(labels_result, labels)) + self.assertTrue(torch.equal(pos_ids_result, position_ids)) + + def test_cp_rank_slicing_sequence_dim_detection(self): + """Test that the function correctly detects sequence dimension.""" + # Test with sequence dimension = 0 (sequence_length, batch_size) + input_ids = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) # (8, 2) + labels = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) + position_ids = torch.tensor( + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]] + ) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Should get indices [0,1] and [6,7] along dimension 0 + expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_mixed_dimensions(self): + """Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension.""" + # Setup: Single sequence of length 8, CP size = 2 + # This tests the opposite case from the simple test: + # - input_ids and labels: 1D (no batch dimension) + # - position_ids: 2D (has batch dimension) + + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D + labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D + position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + 
self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result + expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result + expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result + expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result + expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_integration_with_padding_and_cp_slicing(self): + """Integration test: pad sequences then slice for CP ranks.""" + # Start with unpadded sequences + input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2] + labels = torch.tensor([10, 11, 20, 21, 22]) + positional_ids = torch.tensor([0, 1, 0, 1, 2]) + cu_seqlens = torch.tensor([0, 2, 5]) + divisibility_factor = 4 # Will pad to lengths 4 and 4 + + # First, pad sequences + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=0, + padding_label_id=-100, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8] + expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0]) + self.assertTrue(torch.equal(input_ids_padded, expected_padded)) + + # Now test CP slicing with cp_size=2 + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens_padded, + input_ids_padded.unsqueeze(0), + labels_padded.unsqueeze(0), + positional_ids_padded, + ) + + # Each sequence of length 4 gets divided into 4 slices of size 1 + # Rank 0 gets slices [0] and [3] from each sequence + # Seq 1: indices [0] and [3] -> values [1] and [0] + # Seq 2: indices [4] and [7] -> values [2] and [0] + expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index a738be8736..94082594f6 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -229,7 +229,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, // Loop: for all positions in each row for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; - IndexType tokens_per_expert_i = tokens_per_expert[i]; + double 
tokens_per_expert_i = static_cast(tokens_per_expert[i]); double grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows for (int j = global_warp_id; j < num_rows; j += global_warp_num) { diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 46e0ba632c..b6f9d87bdc 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -246,6 +246,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i using type = int64_t; \ { __VA_ARGS__ } \ } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index a27cec001a..c055705665 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -134,6 +134,13 @@ def impl(): """ return NotImplemented + @classmethod + def outer_impl(cls, *args, **kwargs): + """ + to describe implementation for outer primitive + """ + return cls.impl(*args, **kwargs) + @staticmethod @abstractmethod def batcher(): @@ -196,7 +203,7 @@ def name_of_wrapper_p(): outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results - outer_p.def_impl(cls.impl) + outer_p.def_impl(cls.outer_impl) outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index acc8d67274..2acc3fb68c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -152,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +@partial(jax.jit, static_argnums=(1, 2)) +def swizzled_scale(scale_inv, flatten_axis, is_colwise): + "Swizzle scale_inv via JAX transpose ops" + original_shape = scale_inv.shape + shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:])) + if is_colwise: + scale_inv = jnp.transpose(scale_inv.reshape(shape_2d)) + cols, rows = shape_2d + else: + rows, cols = shape_2d + reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4) + swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4)) + return swizzled.reshape(original_shape) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -286,28 +301,18 @@ def _dims_are_consecutive(dims): ) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - # Need extra workspace for swizzled scale factors - lhs_swizzle_size = 0 - rhs_swizzle_size = 0 - swizzle_dtype = jnp.uint8 - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_swizzle_size = lhs_scale_inv.size - rhs_swizzle_size = rhs_scale_inv.size - lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) - rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) - # Declare cuBLAS workspace # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. 
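        # The extra 256 bytes give the C++ extension room to round the raw
        # buffer pointer up to the next 256-byte boundary (see
        # move_ptr_to_next_256B_aligned in jax/csrc/extensions/gemm.cpp below)
        # while still leaving a full-sized workspace behind the aligned pointer.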
workspace_size = get_cublas_workspace_size_bytes() + 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + return output, bias_grad, pre_gelu_out, workspace @staticmethod def outer_abstract(*args, **kwargs): outputs = GemmPrimitive.abstract(*args, **kwargs) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array @staticmethod def lowering( @@ -374,24 +379,22 @@ def impl( grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_transposed, rhs_transposed = _get_gemm_layout( - (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) - ) - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, - scaling_mode, - lhs.shape, - is_colwise=lhs_transposed, - flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, - scaling_mode, - rhs.shape, - is_colwise=not rhs_transposed, - flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, - ) + if scaling_mode.is_1d_block_scaling(): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) + rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) + lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) + rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) outputs = GemmPrimitive.inner_primitive.bind( lhs, @@ -408,7 +411,39 @@ def impl( grad=grad, use_split_accumulator=use_split_accumulator, ) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array + + @staticmethod + def outer_impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + return GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ) @staticmethod def batcher( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 032ac9eb70..113072131d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { } std::tuple> xla_buffer_to_nvte_gemm_operand( - cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, - JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, + size_t axis_boundary, bool rowwise) { // Set tensor data with collapsed 2D shape auto buffer_dims = buffer.dimensions(); std::vector input_shape = {product(buffer_dims, 0, axis_boundary), @@ -61,40 +61,6 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } else 
{ input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } - - // Swizzle scaling factors for MXFP8 - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Get the swizzle buffer - NVTE_CHECK(swizzled_scale_inv->element_count() > 0, - "Missing swizzled inverse scale buffer in the JAX primitive."); - auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - auto swizzled_scale_inv_dtype = - convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); - NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, - "Inverse scale factors need to have an 8-bit data type."); - - // Create tensor to hold swizzled scale factor - TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); - if (rowwise) { - output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - - // Launch swizzle kernel - nvte_swizzle_scaling_factors(input.data(), output.data(), stream); - - // Set swizzled scales into the input tensor - if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - } } return std::make_tuple(std::move(input), input_shape); @@ -103,21 +69,19 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { - // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; - auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode, + lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, + rhs_axis_boundary, make_rhs_rowwise); // Output tensor std::vector out_shape = {(lhs_transposed) ? 
lhs_shape[1] : lhs_shape[0], @@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out - .Ret() // lhs_swizzled - .Ret() // rhs_swizzled .Ret() // workspace .Attr("scaling_mode") .Attr("lhs_axis_boundary") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c6f4647c04..f00bd573f1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4,7 +4,7 @@ """Context Parallelism.""" import os -from typing import List, Union +from typing import List, Union, Tuple import torch import transformer_engine_torch as tex @@ -3927,3 +3927,212 @@ def attn_forward_func_with_cp( raise ValueError(f"Unsupported communication type: {cp_comm_type}!") return out + + +def pad_thd_sequences_for_cp( + input_ids: torch.Tensor, + labels: torch.Tensor, + cu_seqlens: torch.Tensor, + divisibility_factor: int, + padding_token_id: int = 0, + padding_label_id: int = -100, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pads sequences to be divisible by the divisibility factor. + + Args: + input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences + labels: Tensor of shape (1, N) or (N,) containing labels for each token + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + padding_token_id: Token ID to use for padding (default: 0) + padding_label_id: Label ID to use for padding (default: -100) + + Returns: + Tuple of: + - input_ids_padded: Padded input_ids tensor + - labels_padded: Padded labels tensor + - cu_seqlens_padded: Cumulative sequence lengths accounting for padding + """ + # Flatten input_ids and labels if needed + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + if labels.dim() == 2: + labels = labels.squeeze(0) + + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Extract sequences and labels for each batch item + batch_sequences = [ + input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + batch_labels = [ + labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + + # Pad sequences and labels to required length + input_ids_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_sequences, padding_amounts) + ] + ) + labels_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_labels, padding_amounts) + ] + ) + + # Compute cumulative padded sequence lengths, starting from 0 + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + cu_seqlens_padded = torch.cumsum( + torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0 + ) + + return input_ids_padded, labels_padded, cu_seqlens_padded + + +def generate_positional_ids_for_cp( + cu_seqlens: 
torch.Tensor, + divisibility_factor: int, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """Generate positional IDs for sequences padded to be divisible by divisibility_factor. + + Args: + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + dtype: Data type for the generated positional IDs (default: torch.long) + + Returns: + Generated positional_ids tensor where each sequence starts from 0 and continues through padding + """ + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Generate positional IDs for each padded sequence (each starts from 0) + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + positional_ids = torch.cat( + [torch.arange(0, int(length), dtype=dtype) for length in padded_lengths] + ) + + return positional_ids + + +def get_batch_on_this_cp_rank( + cu_seqlens_padded: torch.Tensor, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + position_ids_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup = None, + qvk_format: str = "thd", +): + """Slice batch input along sequence dimension into multiple chunks for THD format. + + This function is inteded for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + if qvk_format == "thd": + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + if cp_size > 1: + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. 
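To make the slicing rule in the comment above concrete: every padded sequence is split into 2 * cp_size equal slices, and rank r keeps slice r together with its mirror slice 2 * cp_size - 1 - r. A small standalone sketch (the helper name is ours, not part of Transformer Engine) that reproduces the expectations from the unit test earlier in this patch:

def cp_slice_indices(seq_len: int, cp_size: int, cp_rank: int) -> list:
    # seq_len is assumed to already be padded to a multiple of 2 * cp_size.
    slice_size = seq_len // (2 * cp_size)
    first = list(range(cp_rank * slice_size, (cp_rank + 1) * slice_size))
    mirror_start = (2 * cp_size - cp_rank - 1) * slice_size
    return first + list(range(mirror_start, mirror_start + slice_size))

assert cp_slice_indices(8, cp_size=2, cp_rank=0) == [0, 1, 6, 7]
assert cp_slice_indices(8, cp_size=2, cp_rank=1) == [2, 3, 4, 5]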
+ cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + position_ids_padded = process_tensor(position_ids_padded) + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded, position_ids_padded diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f5cd80827c..f4768bb9ba 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -124,19 +124,18 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } // Output tensor - TensorWrapper out_tensor; + TensorWrapper D_tensor; if (D.is_none()) { DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); - std::tie(out_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { - out_tensor = makeTransformerEngineTensor(D, quantizer); - NVTE_CHECK(detail::checkGemmShape(D_shape, out_tensor.shape()), + D_tensor = makeTransformerEngineTensor(D, quantizer); + NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", - std::to_string(out_tensor.shape()), ")"); + std::to_string(D_tensor.shape()), ")"); if (out_dtype) { - NVTE_CHECK(*out_dtype == out_tensor.dtype(), "GEMM output has invalid dtype (expected ", - static_cast(*out_dtype), ", found ", static_cast(out_tensor.dtype()), - ")"); + NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", + static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); } } @@ -145,8 +144,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = - torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); + auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -159,7 +157,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : out_tensor.dtype(); + DType gelu_type = low_precision ? 
bias_type : D_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -212,7 +210,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -220,14 +218,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -236,14 +234,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -253,15 +251,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (out_tensor.numel() != 0 && !accumulate) { - out_tensor.zero_(main_stream); + if (D_tensor.numel() != 0 && !accumulate) { + D_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -296,8 +294,8 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter) { // TODO: Handle scaling modes - NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYEout_tensor_SCALING; - NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYEout_tensor_SCALING; + NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; + NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, From 5e687d1671961ff91f877c7ce325fc393d875f6b Mon Sep 17 00:00:00 2001 
From: Varun Thumbe Date: Mon, 15 Sep 2025 16:56:07 +0000 Subject: [PATCH 22/53] address review comments Signed-off-by: Varun Thumbe --- tests/pytorch/test_fusible_ops.py | 2 +- .../common/activation/swiglu.cu | 14 +- .../include/transformer_engine/activation.h | 5 +- .../common/util/cast_gated_kernels.cuh | 15 +- transformer_engine/common/util/math.h | 27 +++- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/activation.cpp | 143 ++++++++++-------- .../pytorch/csrc/extensions/pybind.cpp | 4 +- .../pytorch/ops/basic/activation.py | 7 +- 9 files changed, 127 insertions(+), 94 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 93d0c3dc53..09654f13ff 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1757,7 +1757,7 @@ def test_gpt_oss_swiglu( forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.GptOssSwiglu(limit=0.1), + te_ops.GptOssSwiglu(limit=0.1, alpha=1.702), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index d249481602..36ef809b57 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -35,23 +35,17 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp dgated_act_fn, dsilu>(grad, input, output, e, stream); } -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, const float* const args, - int args_size, cudaStream_t stream) { +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_swiglu); - NVTE_CHECK(args_size == 1); - const float limit = *args; using namespace transformer_engine; - GptOssParam param = {limit}; + GptOssParam param = {limit, alpha}; gated_act_fn>(input, output, param, stream); } -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - const float* const args, int args_size, cudaStream_t stream) { +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_dswiglu); - NVTE_CHECK(args_size == 1); - const float limit = *args; using namespace transformer_engine; - GptOssParam param = {limit}; + GptOssParam param = {limit, alpha}; dgated_act_fn, oss_dsilu>(grad, input, output, param, stream); } diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index f921851ab9..fd64dccde7 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -186,8 +186,7 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) /* TODO: Add documentation once the API finalizes. 
*/ -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, const float* const args, - int args_size, cudaStream_t stream); +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream); void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); @@ -251,7 +250,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp TODO: Add documentation once the API finalizes. */ void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - const float* const args, int args_size, cudaStream_t stream); + float limit, float alpha, cudaStream_t stream); void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index c23bba2e78..21fbc3078d 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -187,11 +187,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float dact_x; if constexpr (std::is_same::value) { const float limit = p.limit; + const float alpha = p.alpha; const float x = min(act_elt, limit); - const float s = sigmoidf(1.702 * x); + const float s = sigmoidf(alpha * x); act_x = x * s; if (x < limit) { - dact_x = s + s * (1 - s) * 1.702 * x; + dact_x = s + s * (1 - s) * alpha * x; } else { dact_x = 0.0f; } @@ -508,10 +509,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float dact_x; if constexpr (std::is_same::value) { const float limit = p.limit; + const float alpha = p.alpha; const float x = min(act_elt, limit); - const float s = sigmoidf(1.702 * x); + const float s = sigmoidf(alpha * x); act_x = x * s; - dact_x = x < limit ? s + s * (1 - s) * 1.702 * x : 0.0f; + dact_x = x < limit ? s + s * (1 - s) * alpha * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); @@ -765,10 +767,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float dact_x; if constexpr (std::is_same::value) { const float limit = p.limit; + const float alpha = p.alpha; const float x = min(act_elt, limit); - const float s = sigmoidf(1.702 * x); + const float s = sigmoidf(alpha * x); act_x = x * s; - dact_x = x < limit ? s + s * (1 - s) * 1.702 * x : 0.0f; + dact_x = x < limit ? 
s + s * (1 - s) * alpha * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index e347c4afb5..7e387b78e0 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -13,6 +13,7 @@ struct Empty {}; struct GptOssParam { float limit; + float alpha = 1.702f; // Default value for QuickGELU }; template @@ -42,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) { return s * (1.f - s); } +template +__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) { + const float cval = val; + Empty e = {}; + return cval * sigmoid(alpha * cval, e); +} + template __device__ inline OType qgelu(const IType val, const Empty& e) { + return qgelu_with_alpha(val, 1.702f); +} + +template +__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) { const float cval = val; - return cval * sigmoid(1.702f * cval, e); + Empty e = {}; + return alpha * cval * dsigmoid(alpha * cval, e) + + sigmoid(alpha * cval, e); } template __device__ inline OType dqgelu(const IType val, const Empty& e) { - const float cval = val; - return 1.702f * cval * dsigmoid(1.702f * cval, e) + - sigmoid(1.702f * cval, e); + return dqgelu_with_alpha(val, 1.702f); } template @@ -63,9 +76,8 @@ __device__ inline OType silu(const IType val, const Empty& e) { template __device__ inline OType oss_silu(const IType val, const GptOssParam& p) { - const Empty e = {}; const float cval = min(p.limit, static_cast(val)); // Clamping - return qgelu(cval, e); + return qgelu_with_alpha(cval, p.alpha); } template @@ -76,10 +88,9 @@ __device__ inline OType dsilu(const IType val, const Empty& e) { template __device__ inline OType oss_dsilu(const IType val, const GptOssParam& p) { - const Empty e = {}; const bool dclamp_val = static_cast(val) <= p.limit; const float clamp_val = min(static_cast(val), p.limit); - const float dsilu_val = dqgelu(clamp_val, e); + const float dsilu_val = dqgelu_with_alpha(clamp_val, p.alpha); return dclamp_val ? 
dsilu_val : 0.0f; } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9b933256cc..cf6651cd5c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -197,10 +197,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit); +py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, - float limit); + float limit, float alpha); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 83979ac291..4bb93553de 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -9,17 +9,15 @@ namespace transformer_engine::pytorch { -using FuncType = void (*)(const NVTETensor, NVTETensor, cudaStream_t); -using FuncWithArgsType = void (*)(const NVTETensor, NVTETensor, const float* const, int, - cudaStream_t); - -using DFuncType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -using DFuncWithArgsType = void (*)(const NVTETensor, const NVTETensor, NVTETensor, - const float* const, int, cudaStream_t); - -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, - const std::vector& args = {}) { +/* Type aliases for readability */ +using FuncType = void (const NVTETensor, NVTETensor, cudaStream_t); +using DFuncType = void (const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); + +template +py::object activation_helper(const at::Tensor& input, + py::handle quantizer, + int shape_divisor = 1, + Args&&... args) { init_extension(); // Input tensor @@ -39,36 +37,48 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_cpp.data(), out_cpp.data(), args.data(), args.size(), - at::cuda::getCurrentCUDAStream()); - } else { - act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); - } + if constexpr (act_func == nullptr) { + act_func_with_args(input_cpp.data(), + out_cpp.data(), + std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), + out_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation in high-precision fused together with amax, then quantize. 
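The current-scaling branch above boils down to: run the op in high precision while recording the amax, then derive the FP8 scale from that amax and quantize in a second pass. A rough PyTorch sketch of the idea (not the TE kernel; 448.0 is assumed here as the FP8 E4M3 representable maximum):

import torch

def current_scaling_quantize(x: torch.Tensor, fp8_max: float = 448.0):
    amax = x.abs().amax()
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    data = torch.clamp(x * scale, -fp8_max, fp8_max)  # a real kernel would cast this to FP8
    return data, 1.0 / scale                          # quantized data plus scale_inv

hp_out = torch.nn.functional.silu(torch.randn(16))    # "activation in high precision"
q, scale_inv = current_scaling_quantize(hp_out)
assert torch.allclose(q * scale_inv, hp_out, atol=1e-3)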
auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), - at::cuda::getCurrentCUDAStream()); - } else { - act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); - } + if constexpr (act_func == nullptr) { + act_func_with_args(input_cpp.data(), + temp_cpp.data(), + std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), + temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { // Compute activation in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_cpp.data(), temp_cpp.data(), args.data(), args.size(), - at::cuda::getCurrentCUDAStream()); - } else { - act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); - } + if constexpr (act_func == nullptr) { + act_func_with_args(input_cpp.data(), + temp_cpp.data(), + std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), + temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp->quantize(temp_cpp, out_cpp); } @@ -76,9 +86,11 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer, const std::vector& args = {}) { +template +py::object dactivation_helper(const at::Tensor& grad_output, + const at::Tensor& input, + py::handle quantizer, + Args&&... args) { init_extension(); // Grad output and input tensors @@ -100,39 +112,54 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - args.data(), args.size(), at::cuda::getCurrentCUDAStream()); - } else { - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_cpp.data(), + input_cpp.data(), + grad_input_cpp.data(), + std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), + input_cpp.data(), + grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation backward in high-precision fused together with amax, then quantize. 
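For the gpt_oss_swiglu / gpt_oss_dswiglu entry points that this file dispatches into, a hedged pure-PyTorch reference of the math implied by the (limit, alpha) parameters; it mirrors the GptOssSwiglu docstring formula and the oss_silu / oss_dsilu device functions above, and is a sketch rather than the TE implementation:

import torch

def gpt_oss_swiglu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702):
    a, b = x.chunk(2, dim=-1)            # pre-activation and gate, chunked (not interleaved)
    a = a.clamp(max=limit)               # pre-activation is clipped from above only
    b = b.clamp(min=-limit, max=limit)   # gate is clipped on both sides
    return a * torch.sigmoid(alpha * a) * (b + 1)

def oss_dsilu_ref(a: torch.Tensor, limit: float = 7.0, alpha: float = 1.702):
    # Analytic derivative of the clamped-SiLU factor: s + s*(1-s)*alpha*x for x < limit, else 0.
    ac = a.clamp(max=limit)
    s = torch.sigmoid(alpha * ac)
    return torch.where(a < limit, s + s * (1 - s) * alpha * ac, torch.zeros_like(a))

a = torch.randn(64, dtype=torch.float64, requires_grad=True)
(a.clamp(max=7.0) * torch.sigmoid(1.702 * a.clamp(max=7.0))).sum().backward()
assert torch.allclose(a.grad, oss_dsilu_ref(a.detach()))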
auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), - args.size(), at::cuda::getCurrentCUDAStream()); - } else { - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_cpp.data(), + input_cpp.data(), + temp_cpp.data(), + std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), + input_cpp.data(), + temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); } else { // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), args.data(), - args.size(), at::cuda::getCurrentCUDAStream()); - } else { - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_cpp.data(), + input_cpp.data(), + temp_cpp.data(), + std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), + input_cpp.data(), + temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); } @@ -223,15 +250,13 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q } /* gpt_oss functions */ -py::object gpt_oss_swiglu(const at::Tensor& input, py::handle quantizer, float limit) { - std::vector args = {limit}; - return activation_helper(input, quantizer, 2, args); +py::object gpt_oss_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { + return activation_helper(input, quantizer, 2, limit, alpha); } py::object gpt_oss_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, - float limit) { - std::vector args = {limit}; - return dactivation_helper(grad, input, quantizer, args); + float limit, float alpha) { + return dactivation_helper(grad, input, quantizer, limit, alpha); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 51ea1a0e8d..45799ce535 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -138,7 +138,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("gpt_oss_swiglu", transformer_engine::pytorch::gpt_oss_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), - py::arg("limit") = 7.0f); + py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -164,7 +164,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("gpt_oss_dswiglu", transformer_engine::pytorch::gpt_oss_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), 
py::arg("fwd_input"), - py::arg("quantizer"), py::arg("limit") = 7.0f); + py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 5b71d9b032..dc6b9748e9 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -412,12 +412,13 @@ class GptOssSwiglu(_ActivationOperation): """ - def __init__(self, *, limit: float, cache_quantized_input: bool = False): + def __init__(self, *, limit: float, alpha: float, cache_quantized_input: bool = False): super().__init__(cache_quantized_input=cache_quantized_input) self.limit = limit + self.alpha = alpha def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.gpt_oss_swiglu(*args, limit=self.limit, **kwargs) + return tex.gpt_oss_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.gpt_oss_dswiglu(*args, limit=self.limit, **kwargs) + return tex.gpt_oss_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) From 8535dfb4526bbb00dd213a53d6fa52e280bd10e2 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 15 Sep 2025 17:23:35 +0000 Subject: [PATCH 23/53] cleanup Signed-off-by: Varun Thumbe --- .../include/transformer_engine/activation.h | 44 ++++++++++++++----- .../common/util/vectorized_pointwise.h | 1 - .../pytorch/ops/basic/activation.py | 23 +++++----- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index fd64dccde7..1e40c38bb6 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -173,21 +173,34 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Computes the gated ReLU activation of the input. + + +/*! \brief Computes the gated Swish activation of the input used in GPT OSS. + * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor of shape [N, H * 2]. * \param[in,out] output Output tensor of shape [N, H]. * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. * \param[in] stream CUDA stream used for the operation. */ - -/* -TODO: Add documentation once the API finalizes. -*/ void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation of the input. 
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the gated Quick GeLU activation of the input. @@ -236,22 +249,33 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Computes the gated ReLU activation gradient. +/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS. + * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] grad Incoming gradient of shape [N, H]. * \param[in] input Forward input tensor of shape [N, H * 2]. * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. * \param[in] stream CUDA stream used for the operation. */ - -/* -TODO: Add documentation once the API finalizes. -*/ void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. 
+ */ void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 270f0375f0..bf38077768 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -435,7 +435,6 @@ __launch_bounds__(unary_kernel_threads) __global__ if constexpr (std::is_same::value) { // Clamp the gated value and add 1 at the end - // https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 ComputeType limit = p.limit; val2 = std::min(std::max(-limit, val2), limit) + 1; } diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index dc6b9748e9..ca1854db7d 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -393,23 +393,24 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: class GptOssSwiglu(_ActivationOperation): - r"""GPT-OSS SwiGLU with clamped SiLU - - The input tensor is split into chunks :math:`a` and :math:`b` - along the last dimension and the following is computed: - - .. math:: - - \text{GPT-OSS-SwiGLU}(a, b) = \text{clamp}(a, -\infty, \text{limit}) \cdot \sigma(1.702 \cdot \text{clamp}(a, -\infty, \text{limit})) \cdot (\text{clamp}(b, -\text{limit}, \text{limit}) + 1) + r"""GPT-OSS + Implementation based on `GPT-OSS`__. - and :math:`\sigma(x)` is the sigmoid function, and :math:`\text{limit}` is a hyperparameter. + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. - Implementation based on `GPT-OSS`__. Parameters ---------- limit: float The clamp limit. - + alpha: float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input: bool, default = False + Quantize input tensor when caching for use in the backward pass. 
""" def __init__(self, *, limit: float, alpha: float, cache_quantized_input: bool = False): From fa0e9a91e36f85374813aec6df5fffe87b5202a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:56:37 +0000 Subject: [PATCH 24/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/activation/swiglu.cu | 6 +- .../pytorch/csrc/extensions/activation.cpp | 121 +++++++----------- 2 files changed, 49 insertions(+), 78 deletions(-) diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 36ef809b57..63b07e69af 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -35,14 +35,16 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp dgated_act_fn, dsilu>(grad, input, output, e, stream); } -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream) { +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_swiglu); using namespace transformer_engine; GptOssParam param = {limit, alpha}; gated_act_fn>(input, output, param, stream); } -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream) { +void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream) { NVTE_API_CALL(nvte_gptoss_dswiglu); using namespace transformer_engine; GptOssParam param = {limit, alpha}; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 4bb93553de..4c01c0cf0f 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -10,14 +10,12 @@ namespace transformer_engine::pytorch { /* Type aliases for readability */ -using FuncType = void (const NVTETensor, NVTETensor, cudaStream_t); -using DFuncType = void (const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); +using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t); +using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); template -py::object activation_helper(const at::Tensor& input, - py::handle quantizer, - int shape_divisor = 1, - Args&&... args) { +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, + Args&&... 
args) { init_extension(); // Input tensor @@ -37,48 +35,36 @@ py::object activation_helper(const at::Tensor& input, detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_cpp.data(), - out_cpp.data(), - std::forward(args)..., - at::cuda::getCurrentCUDAStream()); - } else { - act_func(input_cpp.data(), - out_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (act_func == nullptr) { + act_func_with_args(input_cpp.data(), out_cpp.data(), std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); + } }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation in high-precision fused together with amax, then quantize. auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_cpp.data(), - temp_cpp.data(), - std::forward(args)..., - at::cuda::getCurrentCUDAStream()); - } else { - act_func(input_cpp.data(), - temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (act_func == nullptr) { + act_func_with_args(input_cpp.data(), temp_cpp.data(), std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { // Compute activation in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_cpp.data(), - temp_cpp.data(), - std::forward(args)..., - at::cuda::getCurrentCUDAStream()); - } else { - act_func(input_cpp.data(), - temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (act_func == nullptr) { + act_func_with_args(input_cpp.data(), temp_cpp.data(), std::forward(args)..., + at::cuda::getCurrentCUDAStream()); + } else { + act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp->quantize(temp_cpp, out_cpp); } @@ -87,10 +73,8 @@ py::object activation_helper(const at::Tensor& input, } template -py::object dactivation_helper(const at::Tensor& grad_output, - const at::Tensor& input, - py::handle quantizer, - Args&&... args) { +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer, Args&&... 
args) { init_extension(); // Grad output and input tensors @@ -112,54 +96,39 @@ py::object dactivation_helper(const at::Tensor& grad_output, detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_cpp.data(), - input_cpp.data(), - grad_input_cpp.data(), - std::forward(args)..., - at::cuda::getCurrentCUDAStream()); - } else { - dact_func(grad_output_cpp.data(), - input_cpp.data(), - grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + std::forward(args)..., at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // Compute activation backward in high-precision fused together with amax, then quantize. auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_cpp.data(), - input_cpp.data(), - temp_cpp.data(), - std::forward(args)..., - at::cuda::getCurrentCUDAStream()); - } else { - dact_func(grad_output_cpp.data(), - input_cpp.data(), - temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + std::forward(args)..., at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); } else { // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_cpp.data(), - input_cpp.data(), - temp_cpp.data(), - std::forward(args)..., - at::cuda::getCurrentCUDAStream()); - } else { - dact_func(grad_output_cpp.data(), - input_cpp.data(), - temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - } + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + std::forward(args)..., at::cuda::getCurrentCUDAStream()); + } else { + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + } }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); } From aee3fb941a7f15e1f254a0de38d2317240b804c6 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 15 Sep 2025 18:04:11 +0000 Subject: [PATCH 25/53] fix linting error Signed-off-by: Varun Thumbe --- .../common/util/cast_gated_kernels.cuh | 41 ++++++++----------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 21fbc3078d..99330cd042 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -174,9 +174,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) bool dgate_elt = true; // gating is ideally an identity function if constexpr (std::is_same::value) { 
// In case of GPT OSS, clamp the activation and gate values - const float limit = p.limit; - dgate_elt = gate_elt < limit && gate_elt > -limit; // Derivative of clamp - gate_elt = min(max(-limit, gate_elt), limit) + 1; + dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; } if constexpr (IS_DGATED) { @@ -186,13 +185,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_x; float dact_x; if constexpr (std::is_same::value) { - const float limit = p.limit; - const float alpha = p.alpha; - const float x = min(act_elt, limit); - const float s = sigmoidf(alpha * x); + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - if (x < limit) { - dact_x = s + s * (1 - s) * alpha * x; + if (x < p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; } else { dact_x = 0.0f; } @@ -498,9 +495,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) bool dgate_elt = true; // gating is ideally an identity function if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values - const float limit = p.limit; - dgate_elt = gate_elt < limit && gate_elt > -limit; // Derivative of clamp - gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; + dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); @@ -508,12 +504,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_x; float dact_x; if constexpr (std::is_same::value) { - const float limit = p.limit; - const float alpha = p.alpha; - const float x = min(act_elt, limit); - const float s = sigmoidf(alpha * x); + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x < limit ? s + s * (1 - s) * alpha * x : 0.0f; + dact_x = x < p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); @@ -756,9 +750,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float dgate_elt = true; if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values - const float limit = p.limit; - dgate_elt = gate_elt < limit && gate_elt > -limit; // Derivative of clamp - gate_elt = min(max(-limit, gate_elt), limit) + 1.0f; + dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); @@ -766,12 +759,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_x; float dact_x; if constexpr (std::is_same::value) { - const float limit = p.limit; - const float alpha = p.alpha; - const float x = min(act_elt, limit); - const float s = sigmoidf(alpha * x); + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x < limit ? s + s * (1 - s) * alpha * x : 0.0f; + dact_x = x < p.limit ? 
s + s * (1 - s) * p.alpha * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); From 87ae3d159c4563151e6c624922f17e771bbf353d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:25:52 +0000 Subject: [PATCH 26/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [PyTorch Debug] Fix issue with negative underflow% stat. (#2107) * fix underflows log issue Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/debug/test_api_features.py | 12 +++++---- tests/pytorch/debug/test_log.py | 10 +++---- .../include/transformer_engine/activation.h | 5 ++-- .../debug/features/utils/stats_computation.py | 26 +++++++++++++------ .../pytorch/ops/basic/activation.py | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index 974772599a..d28db16477 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -268,7 +268,7 @@ def assert_empty(): )[0] expected_underflows = ( - ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) + ((tensor_fp8.dequantize() == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) ) assert debug_api.transformer_engine.inspect_tensor_enabled( @@ -302,7 +302,7 @@ def assert_empty(): )[0] # Second config in same yaml - tensor = torch.rand((100, 100, 5)) + tensor = torch.rand((100, 100, 5)).cuda() debug_api.transformer_engine.inspect_tensor( "decoder.6.mlp.fc1", tensor_name="activation", @@ -316,7 +316,9 @@ def assert_empty(): stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]) - assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean() + torch.testing.assert_close( + stats[("decoder.6.mlp.fc1", "activation", "mean", 200)], tensor.mean() + ) debug_api.transformer_engine.inspect_tensor( "decoder.7.mlp.fc1", @@ -331,7 +333,7 @@ def assert_empty(): stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"]) - assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max() + torch.testing.assert_close(stats[("decoder.7.mlp.fc1", "weight", "max", 200)], tensor.max()) assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 @@ -377,7 +379,7 @@ def fp8_tensor(t): return quantizer(t.cuda()) shape = [1024, 1024] - tensors = [torch.randn(shape) for _ in range(2)] + tensors = [torch.randn(shape).cuda() for _ in 
range(2)] tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] feed(tensors[0], tensors_fp8[0], quantizer) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index ca8e10ad69..dcc9861c84 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -167,8 +167,8 @@ def test_numerics(fp8_recipe, feature_dirs): num_quantizers=3, ) - tensor = torch.zeros(1024, 1024).cuda() - tensor[0, :] = 1000 + tensor = torch.randn(1024, 1024).cuda() + tensor[0, 100:200] = -0.0 quantizer = recipe_state.make_quantizers()[0] quantized_tensor = quantizer(tensor) @@ -191,15 +191,13 @@ def test_numerics(fp8_recipe, feature_dirs): if "underflows%" in line: underflows = float(line.split("value=")[1]) expected = ( - ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) - / dequantized_tensor.numel() - * 100 + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100 ) assert underflows == pytest.approx(expected.cpu(), abs=1e-4) if "mse" in line: mse = float(line.split("value=")[1]) expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") - assert mse == pytest.approx(expected.cpu(), abs=1e-6) + assert mse == pytest.approx(expected.cpu(), abs=1e-4) if "overflows%" in line: overflows = float(line.split("value=")[1]) expected = ( diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 1e40c38bb6..0244a6bc64 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -173,8 +173,6 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); - - /*! \brief Computes the gated Swish activation of the input used in GPT OSS. * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 * This activation has two differences compared to the original SwiGLU @@ -190,7 +188,8 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) * \param[in] alpha Scaling factor for the sigmoid function used in the activation. * \param[in] stream CUDA stream used for the operation. */ -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream); +void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream); /*! \brief Computes the gated ReLU activation of the input. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 3842ab1c56..2fa6985acf 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -199,6 +199,15 @@ def _get(buffers, stat_name): ), } +FP8_NEGATIVE_ZERO = 128 # represnts -0.0 in fp8 + + +def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the fp8 data.""" + fp8_data = fp8_data.view(dtype=torch.uint8) + zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8) + return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum() + def add_underflows_stats(recipe_name: str, columnwise: bool = False): """Register *both* underflow stats (num and %) for the given recipe.""" @@ -212,22 +221,23 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): stats_to_num[stat_pct] = len(stats_to_num) STATS[stat_num] = ( - lambda x, aux_dict: ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_fp8( aux_dict[recipe_name].get_data_tensors( rowwise_data=not columnwise, columnwise_data=columnwise ) - == 0 - ).sum() - - (x == 0).sum(), + ), lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), ) STATS[stat_pct] = ( lambda x, aux_dict: ( - aux_dict[recipe_name].get_data_tensors( - rowwise_data=not columnwise, columnwise_data=columnwise + x.count_nonzero() + - count_nonzero_fp8( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) ) - == 0 - ).sum() + ) / aux_dict[recipe_name].numel() * 100, lambda buffers, _sn_num=stat_num: 100 diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index ca1854db7d..308deaebee 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -399,7 +399,7 @@ class GptOssSwiglu(_ActivationOperation): This activation has two differences compared to the original SwiGLU 1. Both gate and pre-activations are clipped based on parameter limit. 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. 
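The underflow statistics above count FP8 zeros directly on the raw byte representation, since FP8 has both a +0.0 (0x00) and a -0.0 (0x80) encoding. A minimal standalone sketch of that counting logic (count_nonzero_fp8 mirrors the helper added in the diff; underflow_pct is a hypothetical wrapper showing how the percentage stat combines it with the high-precision input):

    import torch

    FP8_NEGATIVE_ZERO = 128  # byte pattern of -0.0 in FP8 (E4M3/E5M2)

    def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor:
        # Reinterpret the raw FP8 storage as bytes and treat 0x00 and 0x80 as zero.
        fp8_data = fp8_data.view(dtype=torch.uint8)
        zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8)
        return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum()

    def underflow_pct(hp_tensor: torch.Tensor, fp8_data: torch.Tensor) -> torch.Tensor:
        # Underflows = elements that were non-zero in high precision but quantized to zero.
        underflows = hp_tensor.count_nonzero() - count_nonzero_fp8(fp8_data)
        return underflows / hp_tensor.numel() * 100

Counting non-zeros this way is equivalent to the expectation used in test_log.py, ((dequantized == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100, since both tensors have the same number of elements.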
From 3858eab4230b5c5ff4381dae4fd3e013ce517e5e Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 18 Sep 2025 23:15:29 +0000 Subject: [PATCH 27/53] Address review comments, fix mxfp8 kernel bug: was not passing clamped swiglu parameter correctly Signed-off-by: Varun Thumbe --- tests/pytorch/test_fusible_ops.py | 19 +++++++++++++------ .../common/activation/swiglu.cu | 16 ++++++++-------- .../include/transformer_engine/activation.h | 16 ++++++++++------ .../common/util/cast_gated_kernels.cuh | 18 +++++++++--------- transformer_engine/common/util/math.h | 6 +++--- .../common/util/vectorized_pointwise.h | 4 ++-- transformer_engine/pytorch/csrc/extensions.h | 4 ++-- .../pytorch/csrc/extensions/activation.cpp | 10 +++++----- .../pytorch/csrc/extensions/pybind.cpp | 4 ++-- .../pytorch/ops/basic/__init__.py | 2 +- .../pytorch/ops/basic/activation.py | 8 ++++---- 11 files changed, 59 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 09654f13ff..eb4edc5cbd 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1711,7 +1711,7 @@ def test_swiglu( @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) - def test_gpt_oss_swiglu( + def test_clamped_swiglu( self, *, out_shape: Iterable[int] = (32, 32), @@ -1721,6 +1721,7 @@ def test_gpt_oss_swiglu( quantize_forward: bool, quantize_backward: bool, ): + # Test SwiGLU variant used in GPT OSS. # Tensor dimensions in_shape = list(out_shape) in_shape[-1] *= 2 @@ -1743,12 +1744,18 @@ def test_gpt_oss_swiglu( test_device=device, requires_grad=False, ) - + # A low value of limit = 0.1 is used for this test instead of the original + # default = 7.0 used in GPT OSS. This is because low value kills decent number + # of gradients allowing us to check for correctness of gradient computation of + # ClampedSwiGLU. 
+ limit = 0.1 + alpha = 1.702 + # Plain PyTorch implementation x_glu, x_linear = x_ref.chunk(2, dim=-1) - x_glu = x_glu.clamp(min=None, max=0.1) - x_linear = x_linear.clamp(min=-0.1, max=0.1) - out_glu = x_glu * torch.sigmoid(1.702 * x_glu) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) y_ref = out_glu * (x_linear + 1) y_ref.backward(dy_ref) @@ -1757,7 +1764,7 @@ def test_gpt_oss_swiglu( forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.GptOssSwiglu(limit=0.1, alpha=1.702), + te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 63b07e69af..d1d21be63d 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -35,19 +35,19 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp dgated_act_fn, dsilu>(grad, input, output, e, stream); } -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream) { - NVTE_API_CALL(nvte_gptoss_swiglu); + NVTE_API_CALL(nvte_clamped_swiglu); using namespace transformer_engine; - GptOssParam param = {limit, alpha}; - gated_act_fn>(input, output, param, stream); + ClampedSwiGLUParam param = {limit, alpha}; + gated_act_fn>(input, output, param, stream); } -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream) { - NVTE_API_CALL(nvte_gptoss_dswiglu); + NVTE_API_CALL(nvte_clamped_dswiglu); using namespace transformer_engine; - GptOssParam param = {limit, alpha}; - dgated_act_fn, oss_dsilu>(grad, input, output, + ClampedSwiGLUParam param = {limit, alpha}; + dgated_act_fn, oss_dsilu>(grad, input, output, param, stream); } diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 0244a6bc64..3b74b8f195 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -174,10 +174,12 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the gated Swish activation of the input used in GPT OSS. - * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 - * This activation has two differences compared to the original SwiGLU + * + * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This Gated activation has two differences compared to the original SwiGLU * 1. Both gate and pre-activations are clipped based on parameter limit. - * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + * 2. 
Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. * @@ -188,7 +190,7 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) * \param[in] alpha Scaling factor for the sigmoid function used in the activation. * \param[in] stream CUDA stream used for the operation. */ -void nvte_gptoss_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream); /*! \brief Computes the gated ReLU activation of the input. @@ -249,10 +251,12 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp cudaStream_t stream); /*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS. + * * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 * This activation has two differences compared to the original SwiGLU * 1. Both gate and pre-activations are clipped based on parameter limit. - * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. * @@ -263,7 +267,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp * \param[in] alpha Scaling factor for the sigmoid function used in the activation. * \param[in] stream CUDA stream used for the operation. */ -void nvte_gptoss_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, float limit, float alpha, cudaStream_t stream); /*! \brief Computes the gated ReLU activation gradient. 
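For reference, the element-wise math described by these doc comments fits in a few lines of PyTorch. This is only a sketch of the formulas the kernels implement (clamped_swiglu_ref and clamped_dswiglu_ref are illustrative names, not library API); it assumes the chunked, non-interleaved layout of pre-activations and gates noted earlier, with the default limit = 7.0 and alpha = 1.702 from the Python bindings:

    import torch

    def clamped_swiglu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702):
        # First half of the last dim holds pre-activations, second half holds gates.
        act, gate = x.chunk(2, dim=-1)
        act = act.clamp(max=limit)                 # pre-activation clamped from above only
        gate = gate.clamp(min=-limit, max=limit)   # gate clamped on both sides
        return act * torch.sigmoid(alpha * act) * (gate + 1)

    def clamped_dswiglu_ref(grad: torch.Tensor, x: torch.Tensor,
                            limit: float = 7.0, alpha: float = 1.702):
        # Analytic gradients matching the oss_silu / oss_dsilu formulas in the kernels.
        act, gate = x.chunk(2, dim=-1)
        act_c = act.clamp(max=limit)
        s = torch.sigmoid(alpha * act_c)
        silu_c = act_c * s
        dsilu = torch.where(act < limit, s + s * (1 - s) * alpha * act_c,
                            torch.zeros_like(act))
        dgate_mask = ((gate > -limit) & (gate < limit)).to(x.dtype)  # derivative of clamp
        dact = grad * (gate.clamp(min=-limit, max=limit) + 1) * dsilu
        dgate = grad * silu_c * dgate_mask
        return torch.cat([dact, dgate], dim=-1)

The backward sketch follows the dgated kernels above: the SiLU branch is zeroed wherever the pre-activation hit the clamp, and the gate branch is zeroed outside (-limit, limit).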
diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 99330cd042..7f5d68fc62 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -172,7 +172,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; @@ -184,7 +184,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = act_elt; float act_x; float dact_x; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const float x = min(act_elt, p.limit); const float s = sigmoidf(p.alpha * x); act_x = x * s; @@ -493,7 +493,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float after_act_elt; float after_gate_elt; bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; @@ -503,7 +503,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = act_elt; float act_x; float dact_x; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const float x = min(act_elt, p.limit); const float s = sigmoidf(p.alpha * x); act_x = x * s; @@ -748,7 +748,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float after_act_elt; float after_gate_elt; float dgate_elt = true; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; @@ -758,7 +758,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = act_elt; float act_x; float dact_x; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const float x = min(act_elt, p.limit); const float s = sigmoidf(p.alpha * x); act_x = x * s; @@ -769,8 +769,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) act_x = x * s; dact_x = x * s * (1 - s) + s; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + act_x = ActOP(x, p); + dact_x = DActOP(x, p); } } @@ -779,7 +779,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) after_act_rowwise[j] = after_act_elt; after_gate_rowwise[j] = after_gate_elt; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; after_act_rowwise[j] = after_act_elt; } diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 7e387b78e0..5885652c5a 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -11,7 +11,7 @@ namespace transformer_engine { struct Empty {}; -struct GptOssParam { +struct ClampedSwiGLUParam { float limit; float alpha = 1.702f; // Default value for QuickGELU }; @@ -75,7 +75,7 @@ __device__ inline OType silu(const IType val, const 
Empty& e) { } template -__device__ inline OType oss_silu(const IType val, const GptOssParam& p) { +__device__ inline OType oss_silu(const IType val, const ClampedSwiGLUParam& p) { const float cval = min(p.limit, static_cast(val)); // Clamping return qgelu_with_alpha(cval, p.alpha); } @@ -87,7 +87,7 @@ __device__ inline OType dsilu(const IType val, const Empty& e) { } template -__device__ inline OType oss_dsilu(const IType val, const GptOssParam& p) { +__device__ inline OType oss_dsilu(const IType val, const ClampedSwiGLUParam& p) { const bool dclamp_val = static_cast(val) <= p.limit; const float clamp_val = min(static_cast(val), p.limit); const float dsilu_val = dqgelu_with_alpha(clamp_val, p.alpha); diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index bf38077768..959eb8ea90 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -433,7 +433,7 @@ __launch_bounds__(unary_kernel_threads) __global__ const ComputeType val = static_cast(loader0.separate()[i]); ComputeType val2 = static_cast(loader1.separate()[i]); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // Clamp the gated value and add 1 at the end ComputeType limit = p.limit; val2 = std::min(std::max(-limit, val2), limit) + 1; @@ -541,7 +541,7 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType gate_in = static_cast(input_loader1.separate()[i]); bool dgate_in = true; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values const ComputeType limit = p.limit; dgate_in = gate_in < limit && gate_in > -limit; // Derivative of clamp diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cf6651cd5c..9fdcf7342f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -197,9 +197,9 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object gpt_oss_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); -py::object gpt_oss_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, +py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, float limit, float alpha); /*************************************************************************************************** * LayerNorm diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 4c01c0cf0f..856a597c67 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -218,14 +218,14 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q return dactivation_helper(grad, input, quantizer); } -/* gpt_oss functions */ -py::object gpt_oss_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { - return activation_helper(input, quantizer, 2, limit, alpha); +/* clamped functions */ +py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { + return activation_helper(input, quantizer, 
2, limit, alpha); } -py::object gpt_oss_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, +py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, float limit, float alpha) { - return dactivation_helper(grad, input, quantizer, limit, alpha); + return dactivation_helper(grad, input, quantizer, limit, alpha); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 45799ce535..ae6575914c 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -136,7 +136,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); - m.def("gpt_oss_swiglu", transformer_engine::pytorch::gpt_oss_swiglu, + m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* Backward of GELU and variants */ @@ -162,7 +162,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("gpt_oss_dswiglu", transformer_engine::pytorch::gpt_oss_dswiglu, + m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 6dfdf3cac6..28d49bf7b9 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -15,7 +15,7 @@ SReGLU, SiLU, SwiGLU, - GptOssSwiglu, + ClampedSwiGLU, ) from .add_extra_input import AddExtraInput from .all_gather import AllGather diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 308deaebee..961b472e0f 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -27,7 +27,7 @@ "SReGLU", "SiLU", "SwiGLU", - "GptOssSwiglu", + "ClampedSwiGLU", ] @@ -392,7 +392,7 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dswiglu(*args, **kwargs) -class GptOssSwiglu(_ActivationOperation): +class ClampedSwiGLU(_ActivationOperation): r"""GPT-OSS Implementation based on `GPT-OSS`__. 
@@ -419,7 +419,7 @@ def __init__(self, *, limit: float, alpha: float, cache_quantized_input: bool = self.alpha = alpha def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.gpt_oss_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) + return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.gpt_oss_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) + return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) From de3080e84f821a8bea3390788115f06800c39e0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Sep 2025 23:16:40 +0000 Subject: [PATCH 28/53] [pre-commit.ci] auto fixes from pre-commit.com hooks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Lower precision gated-act to accelerate FP8 current-scaling. (#2153) * Applying the original precision as Norm outputs' and activation compuations. Signed-off-by: Ming Huang * Adding knob to control norm output precision. Signed-off-by: Ming Huang * Removing the knob and applying lower-precision norm with current-scaling only. Signed-off-by: Ming Huang * Fix the error when quantizer==None Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang [PyTorch] Support activation CPU offloading in fusible ops (#2158) * Add CPU offloading logic to ops. Fix test to compute dgrad. Signed-off-by: Tim Moon * Make sure grads are contiguous in op backwards Signed-off-by: Tim Moon * Add op-based MLP to CPU offloading tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Handle different weight cache behavior on Hopper/Blackwell Add MXFP8 to CPU offload tests. 
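With the rename completed in the preceding patch, the fusible-ops entry point is te_ops.ClampedSwiGLU, backed by tex.clamped_swiglu / tex.clamped_dswiglu. A minimal usage sketch mirroring the updated test_clamped_swiglu test; treat it as illustrative rather than canonical:

    import torch
    import transformer_engine.pytorch.ops as te_ops

    # Pre-activations and gates are chunked along the last dim, so the output width is halved.
    x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    model = te_ops.Sequential(
        te_ops.ClampedSwiGLU(limit=7.0, alpha=1.702),  # values match the pybind defaults above
    )
    y = model(x)        # shape (32, 32)
    y.sum().backward()  # dispatches to tex.clamped_dswiglu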
Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove MXFP8 test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174) * Do not use norm fwd + amax fusion if cudnn backend is requested Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Read envirornment vairable directly to avoid include error Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Fix unjoined comm stream in UB communicator (#2160) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> FP8 Output Quantization for GEMM (#2123) * Test working as I think it should work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. 
Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel 
Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. 
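The 8-bit-RNG dropout entry above mentions a small statistical bug caused by comparing with less-equal instead of less-than. A generic illustration of why the comparison direction matters when thresholding random bytes (a sketch of the general technique with made-up names, not TE's actual kernel):

    import torch

    def dropout_mask_8bit(shape, drop_prob: float, device: str = "cuda") -> torch.Tensor:
        # One uniform random byte per element, values 0..255.
        r = torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
        threshold = round(drop_prob * 256)
        # Dropping when r < threshold zeroes exactly `threshold` of the 256 byte values,
        # so P(drop) = threshold / 256. Using r <= threshold would drop threshold + 1
        # values and silently bias the dropout probability upward by 1/256.
        return r >= threshold  # True = keep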
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> fix linting error Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon 
<4406448+timmoon10@users.noreply.github.com> address review comments Signed-off-by: Varun Thumbe * Update test_multi_process_distributed_grouped_gemm.py change accidentally added while merging Signed-off-by: vthumbe1503 * Update dense.py change accidentally added while merging Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address revie comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bug solved: delayed scaling quantization with mxfp8 inputs didnt work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the unit test error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * just to trigger ci Signed-off-by: Varun Thumbe * address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> TE Gemma tutorial attempt#2 (#1839) * add tutorial files and other local changes Signed-off-by: Sudhakar Singh * remove extraneous code for easy debu Signed-off-by: Sudhakar Singh * make cuda graphs work with non-paged and paged attention Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * perf imp for kv cache ops Signed-off-by: Sudhakar Singh * add code for calibration Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize kv_cache reindex and copy kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changes to make quantizers work with fp8_calibration Signed-off-by: Sudhakar Singh * avoid reindexing from python side Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename variable from previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * use quantizer only if needed Signed-off-by: Sudhakar Singh * functionality of the tutorial tested and perf checked Signed-off-by: Sudhakar Singh * remove files and update headers/licenses Signed-off-by: Sudhakar Singh * update header/license Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tutorial for review Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make weights downloadable on the fly; remove extra print statements Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint and update comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comma back, typo Signed-off-by: Sudhakar Singh * sequence_start_positions should be None for training Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add paged attention numberes and update requirements.txt file Signed-off-by: Sudhakar Singh * more fixes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make tutorial work on blackwell Signed-off-by: Sudhakar Singh * remove gemma FT tutorial for now Signed-off-by: Sudhakar Singh * fixing the headings placement and rewording attention -> kv caching Signed-off-by: Sudhakar Singh * fixes from comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the images Signed-off-by: Sudhakar Singh * misc fixes Signed-off-by: Sudhakar Singh * add more comments to te_gemma.py and cleanup utils.py Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more information about the hierarchy of the classes used in the tutorial Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add better cuda graphs picture Signed-off-by: Sudhakar Singh * addd updated cuda graphs pictures Signed-off-by: Sudhakar Singh * add illustrated cuda graphs Signed-off-by: Sudhakar Singh * fix Signed-off-by: Sudhakar Singh * small fixes in documentation Signed-off-by: Sudhakar Singh * add torch.no_grad() to force reduced memory usage Signed-off-by: Sudhakar Singh * some fixes from recent comments Signed-off-by: Sudhakar Singh * more fixes from remaining comments Signed-off-by: Sudhakar Singh * add te_rope_emb to class desc Signed-off-by: Sudhakar Singh * fix tutorial wording; add calibration fix to grouped_linear.py Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Fix memory overhead of linear layer when all gather from sequence parallel (#2125) * fix memory overhead of all gather from sequence parallel Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update 
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * quick fix the errors when for UB buffers Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Avoid deallocating FP8 scale-invs since they are reused Signed-off-by: Tim Moon --------- Signed-off-by: Yuzhong Wang Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Fix incorrect TP rank calculation when using data parallel (#2179) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> [Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045) * feat: add cutlass group gemm support Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor multi tensor gemm interface Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor nvte_multi_stream_cublas_gemm func and add license info Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add unit test for cutlass group gemm Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add cutlass support type protect Signed-off-by: Min Yang * add tests and fix lint Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: fix unit tests error Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: refactor host workspace malloc Signed-off-by: Min Yang * update cutlass Signed-off-by: Xin Yao * update cutlass Signed-off-by: Xin Yao * further relex threshold and add a env var to warn fall back Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Min Yang Signed-off-by: Xin Yao Signed-off-by: alan yang <89962857+cassiewilliam@users.noreply.github.com> Co-authored-by: Min Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen [PyTorch] Support FA3 for MLA and with CP (#1907) feature(FA3,MLA,CP): 1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward 2. Update get_attention_backend method because FA3 support MLA now 3. Add CP MLA support for FA3 4. Add unit tests for FA3 MLA CP 5. Update attention doc Signed-off-by: zhujian Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185) * Fix cudnn version checks for kv cache for sm89. 
Add cudnn version check in preparation for 9.14 when getting backend Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix for cuDNN version condition check Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .gitmodules | 3 + 3rdparty/cutlass | 1 + docs/examples/attention/attention.ipynb | 2 +- docs/examples/te_gemma/media/calibration.svg | 620 ++++++++++++ .../te_gemma/media/calibration_1_half.svg | 415 ++++++++ .../te_gemma/media/calibration_2_half.svg | 401 ++++++++ .../te_gemma/media/fp8_model_init.svg | 500 ++++++++++ .../te_gemma/media/fp8_model_init_1_half.svg | 358 +++++++ .../te_gemma/media/fp8_model_init_2_half.svg | 371 +++++++ .../te_gemma/media/generation_animation.gif | Bin 0 -> 135280 bytes docs/examples/te_gemma/media/graphs.svg | 232 +++++ .../media/transformer_cuda_graphed.png | Bin 0 -> 369694 bytes docs/examples/te_gemma/requirements.txt | 4 + docs/examples/te_gemma/te_gemma.py | 703 +++++++++++++ .../te_gemma/te_gemma_loading_weights.py | 189 ++++ .../tutorial_generation_gemma_with_te.ipynb | 941 ++++++++++++++++++ docs/examples/te_gemma/utils.py | 370 +++++++ ...tutorial_accelerate_hf_llama_with_te.ipynb | 2 +- docs/index.rst | 1 + tests/jax/test_custom_call_compute.py | 13 +- .../attention/test_attention_with_cp.py | 8 + tests/pytorch/test_cpu_offloading.py | 196 ++-- tests/pytorch/test_fusible_ops.py | 4 +- tests/pytorch/test_numerics.py | 144 ++- transformer_engine/common/CMakeLists.txt | 22 +- .../common/activation/swiglu.cu | 8 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 16 +- .../userbuffers/userbuffers.cu | 14 +- .../userbuffers/userbuffers.h | 4 +- .../common/fused_attn/fused_attn.cpp | 10 +- .../common/gemm/cublaslt_gemm.cu | 119 ++- .../common/gemm/cutlass_grouped_gemm.cu | 77 ++ .../common/gemm/cutlass_grouped_gemm.cuh | 348 +++++++ .../include/transformer_engine/activation.h | 4 +- .../common/include/transformer_engine/gemm.h | 11 +- .../common/normalization/layernorm/ln_api.cpp | 3 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 3 +- .../quantize_transpose_vector_blockwise.cu | 11 +- .../jax/cpp_extensions/activation.py | 6 +- .../jax/cpp_extensions/normalization.py | 4 + .../jax/csrc/extensions/gemm.cpp | 8 +- .../dot_product_attention/context_parallel.py | 249 +++-- .../attention/dot_product_attention/utils.py | 58 +- .../pytorch/attention/inference.py | 28 +- .../pytorch/attention/multi_head_attention.py | 24 +- .../pytorch/csrc/extensions/apply_rope.cpp | 3 +- .../pytorch/csrc/extensions/gemm.cpp | 62 +- .../pytorch/csrc/extensions/normalization.cpp | 12 +- transformer_engine/pytorch/csrc/quantizer.cpp | 49 +- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 25 +- .../pytorch/module/layernorm_mlp.py | 17 +- transformer_engine/pytorch/module/linear.py | 18 +- transformer_engine/pytorch/ops/_common.py | 4 +- .../pytorch/ops/basic/activation.py | 9 +- .../pytorch/ops/basic/basic_linear.py | 3 + .../pytorch/ops/basic/dropout.py | 3 + .../pytorch/ops/basic/l2normalization.py | 9 +- .../pytorch/ops/basic/layer_norm.py | 7 +- .../pytorch/ops/basic/rmsnorm.py | 7 +- .../fused/forward_linear_bias_activation.py | 13 +- .../ops/fused/forward_linear_bias_add.py | 15 +- .../ops/fused/forward_linear_scale_add.py | 5 +- 
.../ops/fused/userbuffers_forward_linear.py | 3 + .../_internal/float8_blockwise_tensor_base.py | 5 + .../tensor/_internal/float8_tensor_base.py | 9 +- 66 files changed, 6407 insertions(+), 378 deletions(-) create mode 160000 3rdparty/cutlass create mode 100644 docs/examples/te_gemma/media/calibration.svg create mode 100755 docs/examples/te_gemma/media/calibration_1_half.svg create mode 100644 docs/examples/te_gemma/media/calibration_2_half.svg create mode 100644 docs/examples/te_gemma/media/fp8_model_init.svg create mode 100644 docs/examples/te_gemma/media/fp8_model_init_1_half.svg create mode 100644 docs/examples/te_gemma/media/fp8_model_init_2_half.svg create mode 100644 docs/examples/te_gemma/media/generation_animation.gif create mode 100644 docs/examples/te_gemma/media/graphs.svg create mode 100644 docs/examples/te_gemma/media/transformer_cuda_graphed.png create mode 100755 docs/examples/te_gemma/requirements.txt create mode 100755 docs/examples/te_gemma/te_gemma.py create mode 100755 docs/examples/te_gemma/te_gemma_loading_weights.py create mode 100755 docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb create mode 100755 docs/examples/te_gemma/utils.py create mode 100644 transformer_engine/common/gemm/cutlass_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cutlass_grouped_gemm.cuh diff --git a/.gitmodules b/.gitmodules index 21492db5ef..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 0000000000..57e3cfb47a --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198 diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 6cd56d23da..61a6ad949f 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -390,7 +390,7 @@ "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", - "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. 
For example,\n",
diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg
new file mode 100644
index 0000000000..16e1a43141
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration.svg
@@ -0,0 +1,620 @@
[SVG figure; only its text labels survive extraction: three panels, "FP8 with initial scaling factors", "Calibration", and "FP8 with calibrated scaling factors", each built from "High precision weight", "FP8 scaling factors", "FP8 Weight", "FP8 Input", "High precision input", and an "FP8 GEMM" or "High precision GEMM" block.]
diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg b/docs/examples/te_gemma/media/calibration_1_half.svg
new file mode 100755
index 0000000000..478604d415
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration_1_half.svg
@@ -0,0 +1,415 @@
[SVG figure; first half of the diagram above: the "FP8 with initial scaling factors" and "Calibration" panels.]
diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg
new file mode 100644
index 0000000000..439f4c16fb
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration_2_half.svg
@@ -0,0 +1,401 @@
[SVG figure; second half: the "Calibration" and "FP8 with calibrated scaling factors" panels.]
diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg
new file mode 100644
index 0000000000..57af23dc31
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init.svg
@@ -0,0 +1,500 @@
[SVG figure; three panels, "FP32/BF16", "FP8", and "FP8 with fp8_model_init()", contrasting a high-precision GEMM, an FP8 GEMM that still keeps the high-precision weight, and an FP8 GEMM that stores only the FP8 weight.]
diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg
new file mode 100644
index 0000000000..d86751e071
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg
@@ -0,0 +1,358 @@
[SVG figure; first half of the diagram above: the "FP32/BF16" and "FP8" panels.]
diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg
new file mode 100644
index 0000000000..c3e4146bad
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg
@@ -0,0 +1,371 @@
[SVG figure; second half: the "FP8" and "FP8 with fp8_model_init()" panels.]
diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif
new file mode 100644
index 0000000000000000000000000000000000000000..25150cb9b64162084b017442a3905c57127c6713
GIT binary patch
literal 135280
[binary GIF data (token-generation animation) omitted]
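The calibration and fp8_model_init figures above belong to the te_gemma tutorial added by this patch. As a rough illustration of the workflow they depict (this is not code from the patch itself), the sketch below assumes the standard transformer_engine.pytorch API (te.Linear, fp8_autocast with its calibrating flag, and fp8_model_init); layer sizes, shapes, and variable names are illustrative only.

    # Sketch of the calibration workflow drawn in calibration.svg, assuming
    # transformer_engine.pytorch is installed and a CUDA GPU is available.
    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
    inp = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

    # "Calibration" panel: run high-precision GEMMs while recording amax
    # statistics, so the FP8 scaling factors reflect real weights and
    # activations instead of the initial defaults.
    with torch.no_grad(), te.fp8_autocast(enabled=False, calibrating=True):
        layer(inp)

    # "FP8 with calibrated scaling factors" panel: subsequent FP8 runs reuse
    # the calibrated scaling factors.
    with torch.no_grad(), te.fp8_autocast(enabled=True):
        out = layer(inp)

    # fp8_model_init figures: build modules that keep only the FP8 copy of
    # the weights, saving the high-precision weight memory at inference time.
    with te.fp8_model_init():
        fp8_layer = te.Linear(1024, 1024).cuda()

The contrast the calibration panels draw is exactly this: after the calibration pass, the FP8 scaling factors come from amax values observed during high-precision execution rather than from the initial placeholder factors.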

%j%!2I z0HP$^aS>K}ix*AQNGG%7CO?k<7E(>rbT4SfLr3Wy$rMpD*U*j^81Xci{EDbWkZ{+> zRY`-29NLN}EVJtuI<7&(4{ehS+6_omdd~n>g|^EN?FN;Ozh|+5cBm8Xg#cTXG#^Am zJ9T9C!r&+0b6p320o4{|ee;N4Ot@uB`opoVXX{DzhI4 zozQ;4Pu%Alw4Xqy4AOb2O57hhw4cN^q4Uaycpw%?c#tBjtSc2wJeVqTkS0F?(v__y z9?A_m$k0*NQ=B3mF2x->$TFMIQ@tV{sUCX=E}iU2oBX6na5d@Tz2K7$Y#o)4;ZvkbT*D__W|P2= z(O0C)f{@c5R~3XS0$afrmObqYo&1!{4_lQC#yK5GRr#E*3R{yOJ{>B>nf#n>16x;z zoQ<@qe96a&hHdD`o{hpMzZ6%)HnEL^&&F3&zLrnHw#l=JNPp4|uu1dCt6*_#r$Tek#xZ$uLivzh>7FIRyk49AM zWG_~+sSA(z{(s+M!)*yJazvX{G| zQ3fX#9KJh?W;Zd-#vG1Hy3lRe1QrcDO;# zgmm+kj;`)YQPJw^Y5@TO3JQwo=;*DjEg%qRV`HPHrZzP-B_Sb^ot@3f%4%d}`CrN5d;Rr)C5Qj-BnPhVlEc%RYS$sy-z0}$ z9S35%*nb>{pNa$YPsibJ#ewK|$KfBvf$&$y;UC2T@|)xEkK#c1i{tQ%;z0Pf!Gmd|$HQp?(D%HiuK7g7jG?a5NY4Ki*}e6Tpealp+iTg}j8 zK7*uIPnxYi9{6xJU9whXFjIF)}`eYTGrqgrwLqQvmg`EIfy-`TmW`qkR8|J2?r zjlk)kp~dtAl`BG10pvIqTil0gpi9(8=;A=TaVd%$93hXeWN_{0Nz^U{LIf~PwjDA* zs-cN-qF@AxJKP+fnwOS)42|1%V%k^R7YY$1#CYZ0zBhxbGxCJw>+H4@vr;6`|Co#D zd5;U^&Mc~xf7Y?CQ6JvZ`yOJ?Ze{+!68;3%ct|`1oC`B?Dx?F zB2+0Kq`u=976*H5y}LsRlYjcstIwZgUW2gG<%PuLs{rq(cW=BwZ-4!^FA$KCNcf)X zrIL)1f0bn{t`52=yHlP|#Q^C&F=#uKY&WRM^6q{A*5dbmCc!{1!^+Kqh^wb`zG~Ax?zX$UC`aY6&Q#kK1dPb)m_8^A_v_cOci6R zkxE4TuTfx2GsR5nvrQX(@P*Bb6=Fpr*38|XSGS+R#%l|1ag!-Z8q9oYl&;iyIuzn` zkY(D5$icP^SDhYY+hS*^J-qJ0UW*Oiy#_hAB%Xy=?iYRM!`o z^x!~w&6!|$_bs+`)pyti9i&=ZjkVc@-Xs+I$`(yXxMfJ6DZDO7NsC*ck zmJn69cL8U(azykgOIw{@%g#`7#XN7erSS8q?L*m0&4@*7jziUHjJ@LiPL_|TJbB18 zBejR`S(_<5R58|M>&WxB?fuKuajL)CMU~Onu5rFP(><<#C&OdF_9CMhEF02jo1sl| zuS1oHrLPgfZs=Sd{%*drF!Pfm1HXZiFqFK%r4RpuvO1ruvo4`*u#Wv3$_eru7~m#4~L2tX+I#SRBBK(Y0=vg>IQp6qWDu?=o8#*U6;r4=P+I_OJTjP>Ij71nNS7|D*of7D>{|IEQ& z*I>1dI=A>Wz z9+lju(;6$TESmq|&(t;}GZy#)CnF5eI?=+h^(B#ulGe<3ep0#j<|V0W5@A8w=gE(8 z4p-wUg3TvAR2$urEm%#GUVrh0z2*iKqKDgLx3$`xexg}X`9il)TDby>S|Jyq|@8pERWH)qb-QEZFZg* z*{nb;pT4(w7i_Zz`Pva|xqW24`BnGf5E5$);DQDP8*l>{V{^`dG7dq)-nOGiv06pF z0TDvUkhMpb`;t^_yD|Lfv$N)qqx~}qyCxbG=NEQQZLKMrQSOJ>t1hZyjMxH0K6pcr zYF?n=OlzFgS#vbmi!D1mOL`ABMS+CaJvD*=W;tvfAQsPakQo5BCv^^z%l6@#U-<@J zC=t-21fl}ew7;?1+9GixMWR|B9XUOV0RUJX)rlN%X+K$we6SISJfMRkC z(9yOTI069V9A1o=>X$fJ3ORiZ)(&=Y)e=&2J~q#IVUwkW1T->z9)jYC2EfTNO(sGj z>;ae_JGslTyyblVc-h&m#QrSN%wEgJk;W_l;*L~ekxFC*wsj`eLSg`Tn8~_+D%Qzw z(K<_%J`9$ACSfZ}Yt`#4JsjdRviwat*}na&`wXpDr}I~VW)D4xO=pO!ruL^rB9`vu zr(bS;)7dIv-VxH9mh+l#u|1%5Uv7D)bmSe@@>hoUCXVt`8YDsf{ z*=G^iS3=0W5X}WG7zw$?mkjPpapHR$-H$z3uk872>J zC?jF0uYHL(Q%UGxUuevULzJs~99I}>s*g>x&pM}Xsw-oHj%&D!59VWNc-Q^#HM#JN zRId_rrm|2k9HDUZlkhfhL{@8fj!s0qeZ-B_i1gNo=GllMcv!1^M4t}h06e_DG<+m9 zvLP>Wj1N7k@pe>tX~aC7VKLPETrTWgNK}Mb6e?HLHk5JgV(8@@92eP9$kd6KEX0#oH2?p^4>A2^q7&TyP_fRDbz2V;SA}hpizWsgmA5N&2*x zSoZpZF2;wsLxT^xAM|zI=Gl`Cx!)LIB!3o39>RA@vK&peK25eIPO;}saZpHcGEV`! 
zgr&G4&d`mfc%7yos$YKGsdxbjsX^wcAz`UuWvLOPsZpn?F<|1fIPSCrg|sB|w3M*4 zw6e4es?oHp)3hAo^t?3uA3URZLATXLIu zMtfRv$7x0vab_zjnwjQmA>|h5+K>j@!JO@AF!LzGbBm14B9mokNRTnd zzR^D-yv;yze(FkBp34?a$gxH|&S3@!$zv68yBy6|-O0}oL?Qq1W@#sC-xQkvq&rPJIj3rLOy z(ZEp_tT>zW$hFiRQvQ+%)h8deln3g`saGuVD<>ksG$|OXu#~K@qckJu@E90_46Q-P z){Lnhl~h`&mwEJ-<1&b`&-w>+3u!)q4vYOq9^(aviI!L?u>CVZvZUXxm)b4Fr{irS9Z8nSUF z3evi3VbtnRny z`ysxkEt}))6NNHb9O_zL>bGblt|Cm`tI>VW?)+!Pg{agTS7-G z(|W5Ka$C}PTh=;q3RZpE_}z^2wj##1T%P*;jJt)F?TF&nbM}gM*}nFw^|l()jylo1 z4N4tJmK|IX9k25`IxO45pp8A}clvlbOO!gLEjva1I!CeE$1NKtBkoLRbat&bEa=}^ z;^|xw)bBK2r{3u3*uv`Gsi59d>NpthJc^(`S#LiR?YiJ;eS+7ON!p`O(Z#ddUF^}L zoYBKo*Hg{YWrx+ve!Dk&y+=q>zxRQCZ6=k@`l{qd^H{P%@E{8nP$~2KqI$k$1 zK4dkFOffF2JV84q);@v=kHIDvc_tq!PcB;B8K|5z?3>)2px%Z}9eYgP%ADM1y0ga= zIi;O9buQL3u`yAPJ$Z#aa0u%6B;IEoU|( zW{Ab_;9kz$>73bLpSewThthhsUu^bNf0lleitcju4S4R({SVW~IYXRLbDCl^Y?rr- zCgyrQh6O&{wa!d_Vm&VuHU9!IS|GAX?LINDIol**P0hbS>HTQIq-J2N){4MrK@JD? zO%!F+*n%n_szMdnIwn8^Cp~*&@x$e!5!sRn-_j?(QJ-34Pz(|%XDRXs0NMsv<1E{d zEjwH;;cNqNjFw%!mfa?ovzGuVkQE=k6+h-LgfW2NsFjeam9P&;amOn_z-k=XDiDA~ zpthP~y_y!a3hV(8^jKtXuIBKq-AcQO^0nylAb zqSo80)}KF{HWmEXL$=Y!w=tlyF=V|l61CBAiIR~S-*LGyO}07Pwa(l{2)Eu`j@n$U z+FZvmLFm_N-;(C|w)R!F4#^hHdN)t2w$3NF;8_6dm}*G}|HtMCG`o!!z3$uVD@uS& zzpzayfP@d+!DHFMx7h)2ITBazkY?||ws$DVcW+OXFu-{ERp~ z3zMVb-DG6x0I=G=guonu*A{V~sY3Qm6}G8qwQNet?M-QkLyW9GX{MRU&i#*_^MW`> zaOH#dQT z$f=%PCXHPoypQOM*q4%jM|H)#}Ujsmo0pvWd$}GLwtAU2n>L)hj~#eUa#W z6Ynbl03O0}m;;4MGlz;T;SmRlr@mSHZaFB{4ij8k4!g2FZ~nXG0K6w@3<7qpgYIK_ z`sVfR{>ySWXuIe4S^xf@mcv)$!omZB-z^6PVN=e$5C!93ECG^1ec3<$ZaLh1+tK2!Ari?Tqm- zL!9o4d}MZruk$)qEL!%7Dw{mcx=KOeV-csr2K@P$5G*6XNL0$>G5n#7K?bOum-DlF0aNqx%K8sVG?7SmRoA zNLnTGQ*ZySIh;^xe)bUzX15~#tot!e&rjxp_=`R~-!OM_DeTLKiE^_eW{oqf4^#E7 z5Dsh7DbcXneu}S`q~`lGH4t{SU~8DgmyN(w_udMarRmOG@t0xkNK5mBH8{z$J!I=^ z&!PKjlFY{X;tUqIb46YQxH0ctgN)4sq(CB2+U7@rSpw_O%J0?K0&!w`9%Hd&7}aBO z4M;pe|E@WRG4g;K@!u%^2lt#@bx zmHb-2iw?Aj5q@p-nHAu6#=>#GcIL`+a0lym(V>$Euk+#EK_&lA_6bY8F3yDr|1R$D zro+Lwe>dO7IbM&zO;|wBV{Be3{NAVERR@?AexK0Y$bdc(rb_&NQML($>cD$}KOpf8 zMj1HpT8x)qQ2LE>;Gm51chw;>a7e+Rl3-Zr^F-jV3c1wl zv?3hUPK>;E9blD&WBP>?L1P~(F9^qt8ezfXCf{X;LFM3yFQgMzkV(^p$lyuyjY`PW zwe2uveQ^Pqw!KLfLNaZS%||rjK%f$W&>gIaW?k+^h0MA!RT0g35VK8&%z5!%63zQO zBMY7P6XSy}1iVoRT?kT!T0<8@-baNlh8a{rmm)q-hAu@}UHcA>L}X#hah`m{D+z%r zVJk_I*4M&ARM=`7B(sWmEu(NUY%S~Bc*toa3t!LcT@;-#{o2B%37*QQ@0q z8&xD*6$g_9*UrNw$#%_6@`&v^Y<|+61_IRx$W9Z?hIF^(ZZzU>8Pm1%z%~`J*TqY4 zMY`Yfj68C`4fd;wHGpgkyH8ih_y5K{$s z?LOSX=Nb{jwFWv+;?JN7%RR=60y@%3&-6bRs-4g}~Q3F4;%f%}I90sPB?_)CHS{%Jwre%B!Wu^@ndY7n^i@xN&h zIKO)kxPN#MxW9-H*B%7!A0hU=_2&9Xky$I}!0KN$Fi(tM8;ETY%2-=I_ya*zU;KB&_ zi$K5#Sd0uLM8X6jgP};M7)a>Y$T)b&_%LKbiW^X>8zj^>V04HDJ>dV%*Db#qpZ{gY z|I3d5cOd`&|A9RC$3Q;V;PB@_zIX$CIA+w$^m8F^x^%se2VGwj{XUU*fP6TfK@eob zi6;+^&Eq*AM2hq0L>~0ZM*hb{9{lS@{>MZf_{&EA`$Qi2&qhA`e7A#R3xv2U`hw+R ze?Zpo`mAX7#UXI~>(<5fUD3;*BYE6EM)JVnKSuJP7{qDO-$wF4#7h2pBu^t$i=nV> zPt*a5L7bLCtmL=9UEdacAyfwjhW(hy%eEk9^4BYQP%M_#^>xv0C(e#ooUbkQWVqMo zMX}<5&O!~hug{C(ipGJ1TN9@?Vu#JIzV3!`N_o<<(=q)S50+_2kBN7E z5I06|LvUMLo^$W|`XG*s(g?STDB$h-g>)oPb0Dsan!lDtI@X9r7N8fA+|&V}0&VH` z9@@$jnFM|{Szu5Xrjal63o^+|WHd&nQJ(Y*HrQQYzGqiTt$yhjqJ4jn)mNTc(P=uU zU1KpM0=-jP1W&a7JdrIMepid{W6%@|4M*jt4ih)n$B2H!eW~tG9iOSCA~)|yJ(`1y zn)W>q-K$6r*`-N+r9YY@==@}-H?*!Wn|P}yMOTZ2yv_a_!L ziwlPbH^06qpD6isREq4lY5IypWjpAIX^D-=*kDcR5NZkfnAkctCe z4b`K6xvvqZPe_etR{Na!zN5G3`y4rVgK&~zh^ms}9JBjE``+C>kH+zL#M$r;MZx?4 z63m5L-J!Wgd=I>=Fc;bLL)$-)eT>Szy~Oy=tX<)GUNoxMGA~1Mx9#P}I4d5V=URz9 zE_?-XRGFHBiedC2lb;grX6lFpzd7xfMK4U{*VmTO8XZje@F^_O=)GJvcUxxD=ZyE6 zjtXlE1Ig5dnP0J2@87T(sk=PL+z?%tCb59Ssy=EYiop?wM-dO&3d}!t>gz#B#>T7* 
z^KMiUXd6!59d9Q4npbY{4_;;$`+jcpm>IjQwR43MAyiU!9Z4QmZfbuVo1WDyu5o<4 zEgt-64A=0u&g!E@S!MQnpSw7Fp zg*fCSKEYT*I>y@U+7Yh5=kus+5f#_8%O-Cxa*q^~<7mG}@AZDP#nFn;5m#5RxPMZ9 zMFS&;d80H!hfJP>tjd;fb*6gCXYDntzPmE-&c8Ite=xDOd$XLNq3V#EYJ7u3&|+di z#W8`Obc?0eVgPP^K>AX7#n#Yux}N!n+z;W^-4tlh)x2r$I0CRmjy|$ZO}SmjSQz7C%m&R1a4Pm5@93tJ)k!5%9lk znKIeJB|pC-dnGjD{W&K3YF8@gN@Us4wC)k#1V*w=?^laM8jC^5*PtWQ7#dsZTAM6c z+bL(O6Q*X{+Tkw@$p(a`cHJ-RNFZ8hga`4^A!nX6NgzJgYd- zw+ku2EsAfI#sEZc3OxqE-a_oC1pojY&}$HTP-30~Q7$3}HaUkBU#S-WhsQ#;qAt(} z2XE1V4uBj7sS;qqLnW1CM>SJbiIi`ztx6ovdK}e5oODZ^Tt7RC(6ZCJ*f9{w4eIK2x2yhmd*CNjP&>XcrAw&Dju(bPVA*@3bLAlKroG3T<=?7DNrG8EtpAi(JX=Q`BS5K~TQaRP6DKQ{ts{WdEnC~81p1M<*+c8uyX=W{!%OZMp#8)PNzMop%?Lisx<=%H z_uhy77|o2T*^@}i2L%tQg9r)v#4FO&!xxS93Jrh+_nFpo+X?YPiuaLg@gbzs&eQf$ zk*OwNa%%MZ!l0l$^~KHM81EftBwl>GPmB+ClHkSXE114C{mXootNJgAL4BbzB`AN`e&R{gs3Rk4yj^EiUgq zLqlGc`nl`)pY;U`9tYc+`AX3Rw#tP#Lpi?48<>^`o0|or+z16wg!-fVebWh5mS<^q zadbiF2$%m7i5|ww<$r}7iY5@2=<1hT8rZMRk`8s~S`Kpy^Ehj5+*n9bRiG{?82+R&~tTet+8KS_8V7LSS+|3HN01~9>T zE};T7@u)c#>K2h1;%X>?o>cB^E2(R%NT0|rlITX8Am$eTAUPfslDP2IG}Vkb zrhv*~_wG9tLpi@;q&gBOM~=qszD_l=e~=bFn#OCMlE%rh08NW?OQlXqwP;IKI*p^F zcl%bKR=S$@Vl_Sr$})%sC{Rd$%Wai-nr6tGzpH8{fawzrz~VCE(_%} zeTF`HW68#QG+C#OnNXK=!!7%yExUm}d_67-i6Q4`G-V4T7e&#g*W6*x%@IvE^Tt`m zX<6wUU>o!T2Z-7c-yw%of(JF=+2k z*`CFxn3Z&zPgt zA^#0YysTcK{93jmNs9Z8BA3+y_3-?6ip*?bMd_B-Nt9s|=p|9=AsNCYiIgQS zkBFT}9D>S9qGL;ZFo_3FOZ?lRMO-EH!4}W~;)w8aXzNL7Bq(x*t}IeA3Oxo9ScR^( zMWtIn7v!_E6``vuWw-NRLNDpciy5Fi<>i6KF;#j{OoiN9OlXZ_h5x;ZmPgQLn(_*t z*orPp=tGi9PvfH3FDt!blS~wMt#XS^S=S@7 zEvwqHAliIX{S8#Drc>=8S$&F0WJ5hxZI6_ehFN2}Q;l*C$>**?Rf2q6DaEXSG>2v2 zRzM;j)mqPFN)nDkGG9J=s`qf;jV-F|CfQoX;Q<=}v+n*QQwT{0BM&OflC@uvjV`@T zb-PYatsb2#nl?kc@6?*XBh$gQ{tkMc!Z%5(j%@z*EPILulGOUEM-iIXpNp4643dK0 zGnXgdB}jh&DSF-^`Rb)^D55dUqdIn{aZp-Om9Y_(i}s#A8|<#+*O##=Z6OUb%_O6Z zIE|Ti&Gi7N4G`K?jdFdd?V;@VQfI+8;HbtAgXFt=tV(GSEent>VK)}oC# zUt0~%v&|RTEZcx2O^sNjvTHjuFPz(D5`Q%g#EF zPFd2`5%_JCyBZ#_pXTRb*)ky|Udv_Au z9`g3i-Why1ks%5z2rr`X8)KtV>P}|t>u~GU#OT*G7?{iG;^giAGBF&=*p-kruop4x zGBN0ZLGAf6SXFA+V`I2#e8_NO=tgC~DR!?vHsWqn_E81g6*d}FITUsQkC}kmSv45k zAN~YWtk56L<7uzG+jXqe>;h{G_ZW72)9HcTWiQsAr#x1G-65jiIZZki-Z`qbI}%hr z$VoTUQ4kr+MUd{5k!b0e<|CRMao$>K)l7l?I^S|+mUm2^wB0(gM{#!4G8;A#RW|Y2 zqL3wWY&)`#hpl<^V)9hEKl!3)U_9|74^1GZvrcErmGWe|lQvXrXdm`k&MJ=iZ2FFJ zV-4ZVATK;Ua12P+)SN^U>dq@#fl67$K!7#t2}7gGVu;n7U7bLs^I}LS&w}sGeifK2 zvM<9}=V4bteZk|8&|)>5+kc@?BIeBNBWMN|dSc{cwUsq!_>!c!`x z1*y#jIcA7t)q;`Wf(j1AfOJtGREQMmBX-$qlK`D`BO?Vao(fNwX&Rkp<0lCtul5t&16~VizRq2^|@+=jCOtLjeq_qNOf_&>WarL#b zOZ>#~wQRnC8a{$`3Z8YT+@2;f0_w+(9c?yO-7SJAhw53DT9=+Wj4-2^Wa_>&> z^5bZ*pz1#5y%kC!8}N(jeKOw_aYJw<*#Wh2#<}>x9l=#Oa&U6h0cGqexe6{=nf#DU za$Ca&L~L`2uYRce3Zyf22&z3aA_rAd@*P2aJG4F=fh5<>y+JLLM})cSwuYcCzGM8} zeWz^TG)2@gX77dvImrF$7;Wa*Zwu(FdU8|nB%~ThUUhLo(i}>c+(=D$|-LGHCZBDF!kt?12O4YUdVvZ7afbT9){`}#;p>+co4 z6>t`4&->4cKJRQ>{ht+m>^;;i@IQQ6A6%=5;!g4xU(O~}2xUdA=zs9#KJ=z1#OjC@ z{dc}xABbY)r==$m`<*ZE`-(DL`LKWf!IzyWI8g$0Wiqe%a+>!yCoEA%+M5E0dOue5 zdhXu>QJ&z)74}UVs8m{9^W|m|shiH=mM*|P8!!kdBYR2qPK+R_hUVHf)&(G z@!_L(Z@h4DQ3>45p+yE9`sF)cUV2FSTB3$Ty9I>P@qI==ldG$E&j%BvF%AGPlOg!> zrVhSd*jK+NcO$k5l)eHyOu z@-uzp?OVu83g7Ya`ogot_44m{ITsHX&;ocy!StOkUrh`AkMJ_>{|GN5=<=Qa8(zMq z%T)hIy!?YMlmAz|{DUr&{};Uc7hQ(^C%pVOT_*iEUjBzJll+O7f2GU6X{5xHS{Ee4?r^^IC@$&C zy!pnad$i(w)EM2cAdAjNM8-1h?}BEaI`n^43em5N`iQlo^7 z>BZFC!mmgW#lfdlIK7f{t+V_}jn+#;CrUdlOMk7(h_U;V!IWv%mgUw}P0at>kWU~j zJtH$KJ0~|Uzo4+FxTLfUah|rSx~8_S{y%51|Eyp?TwPn=*xcIQ+1=YeI6V3vWw3Rb z$6Ipao4qOihZ$@-k%E-(6Vd-`2K%md|1TZv|CYg~{8tD2&kQ!@O8xIhW*D5 z_Foxn(*N4QMr5#I|ECW2GU@-RgG~~o{a-uSB;P~W|G9$={Sm_c4;^giUm@&&b+93( 
zi>W_D*uQnKA-{&Of7yr&BG!5<;)A0tuxKSS8pBT)i|&5}Pk zUlRk~|1^Y+_uESJ$4C^%0l_(;F4Zbd6{~p2yeJ^4E8o~zuEMfn` z`9R-G*u@+Fa6a(S#1GK-59b5^1bu&VKH%S=?7w1%0@GfIi%g-$5S;LHd3NeYn>>>|a11 z?ltNA1@z%wlfGX-AMQ2j`vvsj{vdsSgFc)er0*}#hx?QC{RR4Pev-bQpbzIS()R=O z;r>nfetw(a*(}^;f|!mj>#GBXTD!iK?L}_uGi!0{;dPsj_DAj|bn+g_cUyKjMFSd> z`QPF7*xyo%L1Cj60AY|gG6%$F8PoDpQxZ$MPWs1X<}N?GS4QkP>K_j~(B=_LBlh(Q zNJzOy$0Hf`IuPSdKw_ft%1d=_;?Oq#q=a0$N5)}Nkz@hMaXTxo>@bL9!E6to!+2M< z{L%)xSPN6b&FG~cPz|Q4B&Eq@V#(xlhi1Y|*zT)btiG$GA8PU`%y8#oP{4~HDkMwF z{J?~zgf;hr^?gTugg>{W&jyrkybP`d%P>3fGAe;#>2k%aMY;NBj2eTt;jM9rc|mNL znh#zEb(+X9M!_x^yx1`s$`NjNmoqX zE-Ds1VKNH5JwEgQu=iGBQSW`c_RtJSO1D9$Ae|O1V$qFsgTM^k9YYSo(A^+8G|~+U z0z)fG3kXWb{=3$CR`0!^eO>##-jjU>CmjE|fA{zP+{{Py3KKXpCa5HS5!B=u_2yI0 z;nvSoZ)lQg@7uw0;7M|ElBS>;7d>={?X7dNK7o1S zYxIyhpO8|wlzEas33S9lz9xBzhg_rqo%fnJJbjnvy;R_^)x3%|ZB%q(`H7WJ>3N@#!DgOT4j=A<6gO zD;iMJpiQrd&c&n4`t~IZN4+rt&*d!+)Z1^-dqKPO{h#e2;uLL;{wfTN6o#+6`I%y( zY|cM0P>1$?8Qz0Mo`1ZtF3>RW_>ObDuGg0(%f_JzA{T{RsW0@EC&I5}T->)>=tZgh z2Yna88AtQE5Bu)cISJhPo@ddyt}tCRQA75f?5_7_lAp&UZ-lGOi-k?`-6p}>HrMb7 zZ#qLf$^J2!!&*Y=ZC+dckh!I{Llk}BHuG5B!mf;a)0X#ccZcU93Gs2q(ZX&Je~SRN z>~TD|3{4cHuax9X{*%}0ZTmE`QmboDCnHbv4l+U&)*kro=QAcOPG!k@OUhSzWC2KF z(@1)|4%qm7zVN-CU-2uW%dYGE!DUb7)rD99|9Q;y_TxK8()%u6=U=$>XBbAOy$^DI zJ=dQz9}l+6v=mleSZ}wT^&iU|$KGJP{6;Txwi>H&GC=~}sac}k@@?n;#D8Tl-hSDH zVwSyh$-cUn%elJLBm~j;RSWoG(+~oU>|oJ3NDN$+*}|$Q42R25^e$uPo=?4n+Z4A3un%n z!O0Pe&{2;lNWm#eiO{HuDAU5JIF5Km8d*()Q_G?ispc4on!yf5MY?50wiRKw(nNWS zMRn8Q^+KW?KvB$tTG#_KQMMgX;~dyyV$tzh(NhrInUv@hRP;P5e39iiI)f(W6HUZt ztr#GQy$*@kM8%ZM#O%yO?9s$liN!L17mGZE#MY<8o~A^e&&0MI$AUPcFvVd?EYY|d znD{IR7;hEqS}`W^35-)6MlOy`rX9y77e@oerf!V07BysIq79;>jo%Lry{&}L2#r5C zkLRMr=S0VA<0i!I#WM(^A;RD=9%zDxT!IADgaw@t-jN_T0J@J(c*JCtcuPEy0FkJS zhN+?xNly~h#p9ldC(&vrX=r0A&_fe+2a>c9NnxTdjKKJY;>j%INy5OfX=A)NePi;w z8Q{`1%$gP+A`kcWOBNJ|yICc8HYPYKqT#Rn;JlnEfyMAHHc%KMC89AU3Y`*x_VI1> z#S%dL3~fUZDcT5_mG|8gM8*jsS;;@d$~I96Q~*saLZp@;QWH5-<1{!C3EH4)ab&GF zvR)fm4)w1DBe5KjZ70ZT1oBWcHIEkA0ZpsrOly`)QI$vZpQPbrr7@zAE#UNi&U6;V zw1L&MDR9~hlyDfG){9PWJ4weZN|&D#&-iSGT(t_A-0+=m%-9x3425TGpppCHnON$X zyWq@CMCMzi%yWy(BXAaqGYdU{P@>7irAx<`2)Lm2yJnsB^(25ymMi-ToJHB>hdpR} zy(F7tGecH8m^n3*evsg!wmB~QwDD;0cS}Lt#!_g2tR&;T+zWC@lyg} zS?k76yPNr95lg{3*`-T)F)F%6S;pc~nig1eW>Z z;rZEF`MQJouTH_{kGS%1#|!LF3moYR9a9Sg*YoCZQz4rO!;&oHltLf6R5|OyIX6%c z9Vo~H6rxiUnpzarR1~9AS(mK}l-fKho8~M(?JZTY=qv9FFaMlXK1Wclxmlj0 zP`=Dou~=NODu*z!s#vowry!{O!cuu%uJXrV{v$4u{nO{?5gI$E$j>E27%GXQbhK?8pk!)2bEHDsss;hXgrgo5VES>dUjG^pUxYbHvQe&uJxJ5_ux=kVNcr zPl6k&Pr+3QC+dFeHF!GJ+k4gA`4BN=8WlIORCATom_)TqNzIl(?XArk3f;1M^u)@| zP+3VNZzPcvy~p#p{JWAM&eAuRI(5@~Z%W{GS9B!cIgeL_^;21ud{yxJCxms43J7P3 zhSzfi){@2dK&3tr4f+T@6*Y&7IYYtd8{>gsL?Tiz z-=i^Cr}4>h1Ei!;t)nq(qw%qNDN3iQ5ZaKlg=rsI%UmppT{_pKyw_9}iCrz(EU(^N zuY!$2Hs7ymZXv{OV>xSINylvab8U*Q?d z5cVE&_a3P9?%DKyjqKel?WJ5B>ixXcyF}PG%iTAv(l=?-_b#$eZnU&-XsEA$tFMQ< zfglTnE6^XT+8>hEAJ)kI# z#e=VO!`0{~sZoN@ zdZVjeqw8s-n}jW++w-G4=c9WJV+8wBW8d}04!y>X3Dd?-TgJ}k$1cyu2tbVEn9}3e z`s27@?{WO}ae~(IYj_Le#QI@m|FPJa|3FSb08aX~fc9tJ1eoi9!v3qa4(RNF*A9s6 zfZz@&?10w}`0Iew4w&wM%MKXrfT|Ao>wvZn7T#EC**=PHOOR@ep`)nyU{YrSS@wb7tf8S?g`hB47U-#J<{upTcCzqoC^FZ6*xD@?g2HO70rD*>$ z(Dr98MfyW0C;4Ng?H4RX`sbClpZ+=NUsu{lqsaawp#O@cNd6?C z|HM)xz)sum0y==D$o?vz16YdW&jLDtrAYoPpaWQnQuF7+QQMf@w5`VW>O{*_Dp2TKwE!li!4Qba$w)bCh| z=oc>aA1p=ulS`=SR04{ZqaXd^^D04J0=5u`7n<9BpiKkJY*mAJ&ixm}=-^a(_&0%VN{1Dvn zKDp_8zJNEU@^c+zQt$TyG1d-MD+zk!>i0t41#!+NrYq00E?*Qq@P7O>I^uaATb;0q znYC*A>EwrpFG~c2EngbDyez|gS*q*$ShKa{c~wY~u;~x3mv_djQ)^9J%WYYsb@K_P zQ4!`9s^z8n_X0I@H(e_;Hn`QV2((7=4~xus%k*w2Pa(;LG~OV@ZPl0w(L;*$#o^^V 
zhCsJB<(USs&{}3HL^wMd=2nu&`_)`+3q4sPQl0j-OhxwI*=%2xNGWEKoslWGZt{DL zhN#A`w)%wgN&S!qD<3F2i?$_#6{@83|{cW^u zf%E~e*%s8Rd1q$C!G1TA=E4Ntei2-xgr~X;li{T9culsWc@$~X$uZN=Z6{#Jq^E(a zA|^l~vlqM+{}JDThM%#1FZ7I@} z$KY~bvy*4F0@gUJnc-|&dOK>h_+f3MV~mh9dVaO&_^8QMZ797EX;B+?1PJH$Ki%{D z*XU2j!_;bXs7Kc2C1p~5$TfhkOU2=Wuz^S%1o+k!DzuV)^VL#Jb9aEo@!HvPue%j6I5(}G)^*Yje zn7F`4&UHV(ig!c~@E|DW4As46ZAO2euFu1_iu@|N?Dmpbys%5D+`1U~b{pR7;t-4r zQq-1HvU;_)9l#dvY=yPeIbgWwqe>gjbFjGZ?6OOMXiD8+)7 zd@@kiRup_R$6f2(BLE^r9gJhS9x({=sOPZM>Iki+A$R9QMY7!SD@%51n;{ES^y&&D zd#32vw@#*F88u`{#&8@(Z|on(5lw$5l6@+LtPCGi4v(gO7rAIjCUNZfQH)GZD~9rE z(7F~GOw4Obj0|2Bb#s@4YJ8%IXM)))kdw_K1HUxIf zDTV++x;X;_zX^)Rj|0U;mCMKBzm28QCe=s9VLN$a-QI}9_!fJUmedg+Nj^>DnPdEC#l#DJ^ zCTWr*f*33g-|m39A&5y3@Qv}9!+!X3M}ol(aLQ-(*r0G?Gib_+UvdN|u?;$9nFR(@ zBE}U*d>r>wxMhWy_k*XRi9FGW1(vw14I-*=uiQqWC}`>oETsfWGz~}GC|4pbM)b=n7(RMu{BB1UN1IBBwJQQ9ahpaA$kt}42UGrgbG zUBxPW$Ot*Rajl6BojyG7sR7RD1BDYRW(=^TtQTJkY0T&sNZf&5n@dUfh9-OkPiOu? zdo3Q3*{zN^FV3Vp%IxYuV9pX|YG<`B!|_iEH*Yj%wT)*G4-%HqW;ei6DXa-O6~m|_ z2wfYqQOlW3I)o`Kv^glbtlM;ipP)JQMrmAJg!$;4T2P_@9pRpMZk1g6O%Xc68bog8 zI8v&F!1*M%NG@rGGp`VoCNGgU2G1)cP4>Mpke4NptU?UqDAnp~2+WkF^~j`b{ zFUax0&%Uiw62;;b8&Tp{Ra91jZ>*GANr10lky>MokMt-F@+)pS#q+&2TN+Z8+o6mv zBT?qdQq)H?LD^t6q3~ird&!X;)Rqe948^PFLliUV*KP z+v8DX=~qP1j9W2VWxZTP+Kd~2L*k94dNGv^ZXx`Q6)2t|5_h*Sn^_XKP@>wjs)DT< zhxG=s+9<1pha30%X|*9~sUSVBicXDTRrE-BjSj3_N)iX_)>(}XOWA!S4p_JLxlxtk z7WPn6t=4i`>1M4mOSvL@ou*MavxY8?46;tiu~M%Td+D@Jg{8ve44Xx#Ue&L{vKgC$ zrMdp$a)o^)_BXl)1=5<=k=O>F4RTqP9xB)uvkmu2>u>m-VFgQ|9%a>rAh8sXD5;LB zC_-$KGgMrILp(jUlvOE5(h$~kAqIjFTfrS8{S1qJpfP(6Yq!ude+a8pp{ckO%VMrc zST0GAf}mMw*@bR7tX2|>3f9~**W7m2Owd8!(ka;j5UE~5&z62<3xK4C=UNCx&sqQ` z^#T1{#&uPdY^a?796$r@#!IH>Emn>vggbdUbN7b#k?I0u?j@zVl82hAu&=uDf8p zE)lOTv9vA;!j`Vz2o*y&!F{RjhkD)eUfqh|v~D1TraIrPMtI(>&d~Ers^^7X4;VI1gjz;0fB zfNu`e&Vfo`Ul4e{*~hZ_7I>B0qh2lh5!%)AQ%AC0PO*Q z+TfB>U|?Ybv3Z~_@GGeS@&f>=afd*8dHIr(;lG9o0MG+q9RTzIlm~!2ztSH7{QyV^ zNEG~x+yktKk@L?M(*Kp8ff_;l|JS0@-+9pg4hKE>yQuVc9`s+}Aklx}LH`X7;`@UK z{Wmy>^MB?k{Xg1wjA>5!FucsMra=dr_&{ zF{Sql33C5>Fr%thUi(g)ok)20MQQ8Z>)NGh4<~IUi8_Aslhz&|E-7O3f-rF?1Rx(n z>D5r5!l-XR=Ij^#)FjE)LE;qlx?}y8s}wq#;Gwd|s1y5Y`Pc9Yz&sGD^W7DhjuLoFNDMH)dGA+wcD&Lk$!a=06|ZL_rd zO04xZom14b%P0DG$9II;{G)vsR+zrgC=4&e!RUa|raFk6&?YRFh43>Oi1QJcI6nL( z_owySs7H^i?ZXth7HMxZp!<|F_MACCGZTJHQGB8nlMqaCiw$kruMl5t8h;m)wcK*x zrC@bZo!)1D{q;f3_iw^uygu=B|HFUVGP(d(ukL=JaX&`@02`WvdV75R(sf?wUwn zY)}T4jOSzSr3sj=D+)EJ`$dp~p0rxtyMm0%nS3kI)&yxeXIOr(St|kPaBmqL(xga=J{L-lIRy-8DEw)VQtWeU1vbOP=fICt8k)E4i=>vj}L#;Z3b;Qxk%`L;TpDmRQhSWejSvR z~vFC9YK1~{kfYgLYH>P;V6M65R*h|VvR22E^ zJLFup#(gFzQ}+6DK596M>EO68*)$F29fe{B#)Iy$?Zcv5KHr=`^+K9?-Xv5=ym-@%bKDLshS=J8Kh~0ncmjGW#~5Rl^hG z6gWUGk93eG zyn~Ex*K15`p0aSig;!=Tnr`~zGL>8RwD7leHocE{4&Mgu+1bb1P=7w4I_aNy>Fl{g zvcw=N-nX%^?>${V8E1AhoYD2tWps0{aM?h-hpf#dpTD(-Xl$j)N%FfBW%}}C&8AV| zwoe&CdPNt94pyc16{WGrEq&Rh!I~Jix}*4|x%AWa0b;%y3x*B)vBnX@Fqalj#!aKx z=Jp;%_nh72HG4zI(nZE`?ooR3yXdBv;JKbo-=Q5-aQe0bj@{!0YW|m%zDzyYp=axjDv&uB!XIt!t(y&_ z$0JwLhwI6w%iIoLb~g;a4LY>#2vvH&kZrHYfy;d5)VvsCHrVoXm-zHNhwA zx+T)kPYG{f^y-a@s_}MlG5Upfw;(_{J+B8$}}Gl^#`no>M0ckT}OFs z^ttJ1`OD7uW`zba()cMtg5OsMc})jB77uT!q3E^&WPF#mrcACmAMNKAH}(3*rcf?X~h| zHy)J_<`l6;uK8V)k0_ECp25{XLPFBR%$&p`iWD6RAQ6PC;T3oh`7@F|QxWe4#qUqs3zMfLUznu}AnC5H0M5ixsZ6 z)iH`2Nfn)m61i6F1sF!q!pMo<$dY7O(|+jTZC5H0+s&ysI!$@OqgT}7aj2ZQCR0PI zg7~RJgZD|&h?z(>C^rXK>+nN>XF^z&nSQlv0*R%m@lnE+a6;9y1XDacMPVNaT3vkHz91>X+SfFwl}B^EsLniNecf0p$5NDD|b$Zf=1TiIHk zB;DUgo}_d(Q%tTfN!H&;a6;I*G$v@H6KosdWyi^PG~sRoToCclpn=54#c;jilt?8V z4KzGYG%OUs6^=-N;U~wPz#pUGXjbhmcK054L`git2J8|gj)*}&@&JeS9jEYFrWOXI 
z=6p%=K_@$-Q?>F_Udv4};St!V$eEPEo<>qs8eeCrBtK#JMQ#BWDIbG|fIKtynwbJ10D=I0L4H zFcC-Y(sF$mfVU`R=x(G=qti*)ldhDwFlN(EHW1O!^rHbd$spn_I+GNYmHR$x1DsvP znT@-N?8#4h8vcwfL^*v$DW^X?%P~9yFl2*^v)Bo^7*f-(b7is2BENuh-)&?HHE{}i zWQcBNVR`7^w9Yy+W0fh;l&Y_ty$a1%-po~v;Jg=seAbkqpzJxfp&=EK z&61R#UY-9!C*R5=*O*Sv)FYqPy5KQgCN_cUgW2q1bb%7I&=_23##NY;8c^`)G!LgF z_s#9hH)46u35xtT^F$E2u1z^19*@K5kdpmH(aMEf5}8j26EaUylQx0zYtxie&H#d9 zB%PjZM4>fZVnkDBeo6kZaOvsG(zP$~aLdx#&G@F-(i7I4+SJl-JsAOwMQ?OC2{(+e z!SW47WkmIN&n+W+v8=8*$|p+h-UgfXNfd?`MA^Go*K$M@h(vx&ExImW!PF>NiYk{& zeP+|JTA?&$vpZ|^ZPR97$L5Db<)L*YTE}#IlS@0=iZN2daNE9!<))dy77#Bn9@ zSrsY$8$c(b(0xNC=J|#O`G&6f4a3|Urn5IJx=e>OLO;Z+#csv2c~&dIs*h-@-Nj)E ztM>(L426bDg>1y}5H%^Q)fEBpWzCiqYZ6A}C*|v@6^%4;vMRQz{WbpNcUk34ZsFFk ziPSwKtdoyaeuAufT3Yvl-b7OpTTi#nxU|ktw=4o($udyu8*j^lg~7*Ju|Fkf&z<9l ztg_pR=LNfLmDJlHgYi+p9AQ2N{XuppKc>6E;U2|aNOsF2p8(H*2ge?dXo^3Tpdvll zqlbchB~eydC^Z^ihtu?4k4D${Mt>E5*c^~e@e|j|DSFftkiXS%2KU7iYsSS7F7-r} z+e9Q2HoKTK#mxm5mwH|HHHO&~hio)EW45p!H`Ycrd*XX1B2fdSEsmWn)+x=m6k!|u zuF{^d6&iDzoy*x&MbK$`vy-l@)fT@MlP7pdrIl>B)8bR7d)V6s`i=x-3F&;BnABUh zb#J!$#!Hn>!Sk-5&8GXJ-DWvm^4GcyQQZ!xw%5=ub$=AoJ?z&qDrTE? zRoV2Ti5m5=*rH_hAZg?2ew&(ZHFt=y)4C1OlUcueEEj$C|USu z399xZ&rsBAd4`+F1JAMh$T42r3MZSmS&1=Y!6hVVE2wLlI z^^TKE*HJ`|hg=t>X?;g*XJTp7i#fuES87Hs{@%{=J$|G+-cS{fUGxp|i}wPI6N1tc zcl9SkyeGucCnU&QC!`i8WG*J;7$@&bPd?P2lqdI|R7{^#Zk<$Jm{hx%RA>C~j9mJ| z3;hoo-XFBmKj^f6&|COmz;^M$h;hn9dg_%lD_9-FjFFEjH+C5JofaJUSkuyH)gU^t{kO>%52Q{OHBJ*TnpU9WNzl0eWj; zI(lI?T9j&hVR|29$y@X~%i`4i#h&QJ4>^mTE{j6j(!4v23*UH`ydTW|xZq*#5XF|8 zIr++S)+)+qguW_Y^8C7l)5e3*hQ`iV9(q3$5WRed*oTMI0DZlU2f_G}x`Kz+{v&-1 zPkAev{PJTb`6p->`nC@bTZ||j48&>A!|fwlD|GRxi+lwnuyU7)MLB%g9+ka^d_%dAnMOO{9_4P}0 z=a=uit4)GSb9Sp128$rjlC=*vp~31M2YYU(7|{#qHOJAV?#ruyjY8w?MD+w_hr8yfRqZ`F0E4f6QWj8lp zs%_>~uN1~?=HzS&l`(C``K**)Zf1U(lULg+=)|xcT|lXEJ#Sm6S=>VIZ!I%!cV=*Q z+b;tBdied-j@E77VFOOf7*TtsubF~hSM9!L2Yj`w_?p%E^~26rB<9Yi9nL^v&@F?V zWWjapy$#+?AK>@(-3rccJ39wzoZ%TeN5q`bKA#SOf<4m~l6WuPWaCn8FUe#NBDBYj zf5nlpxR;Q#;Ys{0`sOBs+_$(-n>4$7Z$?+>D>>+{zD4b?3=@6jyLrXVkcH9UvmX(^ zZD+sFGse#6uyDthgKun0*nwlzU|T$vgP>|lVClfU^Xtdkyx&9a@5mc+C}w^Stp2Xb z%)ty=-M#q!*mvm#KgV{(u4XLzkB4nP{K$88c7M3WuV3*Uy1w5fWInv_hhgn-XdG}D zZ*a(}juF3eXqU4WpRt2t+;NL#BQeBw+WF)?#+JIXgvs&)BnOfR+-1vAn+3x}?b?rx zZ*B%$&E4Za(a+h$2Q5(hoj6plILUm=$Yc|#5UmzDwKv&{&&4PrVQ=CSr&e4#wSI3# z0|POTZn?VvuN7>R>YsCKFq#e7FgsQgT&CZWu;Ke1Yk8b&H@#l>-Bf38GhF54`|ah{@3+-pe|tWkG#P3yDwfR{v$6ibx0+u=??!(UO*d)Xhi=x zFYu>&{rowDw7z$`+#mV+OD0)E-^}(u^7Rl>Utk0icoQI`X-Nk9 z^}H{zjl0woJ68bxKRa6Gd+^?vpRzNTaa>VxR|PlbVg{U?7{oaUG3G|sOEw%G95 zVD8T5Bgl{Bd*`&Tf5{8{lCS?zs@)(@)KOyuq9$;uhj`PwKSu-o`d|A3Kl}A zLkQ7b=r4VNVoi3L4yfDLIbEHfeSzb>lt#FyZ`&W2`J0_KyPK_Iy#BvrO&Gx>t=3tU=rW^ znXhNQ?#GM$@@Kw2_JjmUg5ROi%iucylp4e9^Ur)e&tM;O=+At;1y?mtnfD@Bj!Ur> z9OM!Q^7T2Ki0)tV^$J5kEh&+~@DcsY3nb-CJO|?S zFBPBLQSwY_346Z;>h&*QPwKiZd>GS%UQlX@coQYhm^3p?&Adu_89!x~?(L@VDm$Gg zdG5+plAj7(EI*@buELXa*PF%-ayNr-XUy_XDPz!!0=u+&H zjFSks1CvKcXl#ZDG_&#+)0t1#GprCpgETNI z)d*L-sDRLoG^q8|2ww?ONR~_*Y&UB}z^2__cDIf+Bt^&gniWm?^(oTOC#lB7(6AiF zL((wDCS%g@qFiPwcCv6Voe4R@vYb_fEP~3zgrc}V>y`#tqZt8e_}$z zHV~F_Cz&kTTilcm%>n1ECyPnYHf7jYPZFIXiw!}TGM(TjNglG3!M>lGvVg6UrK!l{ zu;^diq~%O{AVMB5q5bMMCn8ZngFN95;uRa4I5<(ojXaS9{fa{gn(!o)sRFf)y}0Q&AS%20nj=n2pR7p)4^K0~SBHCW|#FOC2EA8jU)^*Vz zsO{@vbnr=PHoUw)-G0Su>5aoM0-~nL!j=~hB57In;$ob6+hgk(v;j6Iz;fxNmUjpd zZZRNs7i(k2brR!clrqtYnCYJ9af;9lH=)vW>EV~pk2JeBLNhIvc$cx)NzkFiRPAu4 zFA3Wj#wq^l=7VQ_3Oo66NooDJ8AWN~sv&|mGNry~DWeCy%bb(7!_D|s-VGVXn5HoH zTvHH)BcgTXg*e7qRs~DYBM4iUw72*c64dBX26uVqj33*tenyg9aPLxHNv+Ef%#J_i zd7X0-#%L*j%jsQgMQ$!DW4EH_>@d9db*gadhN%1O#H&2U*F`q?*3aCXCJRHdi(f|f 
zyl9XJE(>v`FqF0-A7q@JQY~|>kVM%Sn2t}kR~S_VGv3pqrVs0u7r}~UY~5zynnUB* zy470Y+gW2d&Q9BB*VU!V+Pfotmbh`<&kKUmo#2T|t1|A26XiRu&C?24jz`Ub^bXJo z`nfr85#Vzo72E4(vV?=cKWzi ze38s4)|cV)4FIz%88|0oGJN4A)jS#0UUY`H1KY*$lVK|ZG8g0$!tzNvBVz8W&wFeHpinsn2Ex z3`pH-G8w;GX`aoZ7u_1#8FxsW;Ait22JX#EjJwe3vqjRAMfdhA#yx-P^JTChq>Gt^ z>07ww`6t>XNS`dzejJ?S{4=K^bjXnD05N^OD!xPt9rI=So=ttRu4L#jk;(Lfs95u2 zQ+vr{x}E9p4V>g++se>$Zi(rrar$Bhx3+48l?e;lV(JB;z0wQjF{B7}PKf`f&_q5>m`t>X_^> zOzv?^eo`zU5_K$57?$KXmMkgu19fZ#Vi>l{IQA1#oagE|nlKy^-Eka4Qe0DYTniYk z%{VTJ11X-1I-WZW&vP8lmlU5UP#r%Mh95bOA4>{O0IP%HFtF)uAZh#SM8xfXNEv{U zG3!1Cs2MZ%!@F}SSL>PQ%Q+NxBxc&G0P_N{ zHUNDCusHzv0WbBVf^in;Ks}SKPe;d z|4l^vzln&yr(^yPqzucyN*QaE9epgnlw(5YzYSo4e@Yn*+xtU2+)ww1??$#f91$Dy zav7CEmUaK;WKi?o9aMadaqwQXLA3*LG8Vp1s_@_Y@qt%2`p1+C#qw|0IYw9Pb|Mtq z^jW`L=eW&1c*rY?p=vv_KM%}53NJp}dMv!8p1^y&Y@BHaS>R&M`sFqUvf_j#(><4{ zS>g6Mp3wMY)sYb6cp|}tPh!=}H@ah^sIcO6vzU4B$H%H$Z}wLJALDp1bIks1Cn_Q4 zh_nUI*KgMai~oG5d+6fFw`O1^>fpipsMqeh(zND5dDV8YooT`>>96xxd>2Pip@vP< zpCxyszK_LLN|UXQwK^XTYw~xjVGkQVyf|E`Z2#65nE6AdLFPtYrQfz>+F1d6$5}Xu z=h+G~_!GMkP|SIHYeuhp~PJTUm#sDIGPuq=6LM^mvbGC|gMEVLqKUSyw!k zaun8kUN}ZK4ism!I`h)U1Me%|n7qX6puznX0=Z>E6b*TZ*R~fR^7do)9m~h`#JLN4eWLyv*3CCX7H9qX^UW)W0w(^xXAnW8wDPF@Pn4dyT}KSmu1SwI_8QkarElMT_31s{m`6YY;zM4(U?EzC&&LJP_jzqz zOT1MT_qEB+YrT61#bfTucQXiC*J%4JM}yQRGXsZLAMqwX?f)p`yN`WKTHIsYJC!1j z@x!_bZ}>CSlGphZHk?0XlqX)_e35n|pY55;=Cd5!io$Yi4t0LE$sRk3*gIx<51a;H zl!5UX0e*E$;+4XFO5@QAnZS6$=7W>v*DUX$VBJNFl=exkx<_+7lt2r z**v|m`LxjFC4#m5w!w!+&62|(p1va`3~S1o>2%5E`ad?!sJW#gMcj<77#CGAwlrh{ z$g@R*xwW4cYZWEGtjzS*vb;5?QJyMNlhewhm1M1%5%VMS4bi6QlBmu=NPM*uudOv6 za%Rz41d>6t^VVEQW@X>L`iNGIUPwtqD~_<0C#}-sbtP9QX3J zNBk=}+y(*VJQ&9#l|Dy9OBaiGnx}<#q1ZzIpwmU!HFtGr7fY>_&yP1eHLd)9@z_74cq;7(@~ldW8@Ewa>hZ zc^T&QiVcYvke4}#|5mt)lh-(6&m@?J_F5I~LXTQd31QL5`AHu^$NY$Ev+zbNWCF#9 zP1Ih0OUzoAe-QmH#{N6`P2L+D4|Gon5;gAT2}QRnxin3{iR+3K9&bKFo=z62)s<+D zWNvCSpMI#4xj$yk2tvF|wUBHU9(69aN#!-}?!a{``xgAEfHQkWb%L(?OHYUA^zI9H?f%a`K-*~PuWU1W>#7wfUm5`)0$ zy3O5%4HxKCab9ZyLUd{E938Ox4-guv9QR8Lc-S-gimMqE9ROHG&*b?45njG znE2Wy50+k0xAcalx^(T%H&%>ruP_rIcF8VopYcd8zC`}WcUW+ZT=E*1^gU@GlHPmG z>}Ag#JM&KV zj651@hd%UQ@>n~Qd7HxT^P~I;-!}J)vuwE3XE3ujdMsmY%vWy?HbBCAAp80vXsvC( zX4mJKZPD#O!FSZe*M!>lq|j$c0jg!;>Hf$&J;?*pk{7#32TTKfW8tT00X<{aD5Uly zL-AY})1oZWKJC-l6!lsU@{L~cXHMa{v2O5QGhlqu2lv?L~e6 zHE)i7|D8TZwZgy;oq>K0-uV4~(qcjEtUe@=phQJ6UgcvCRfzvF+`o?6t2;ChXu-U5 z4wg&t(Gd#}DGF9kaW~=!nZF;he&qRhJ@{NR@N&&>JSot=pU06V&}GKEb1mp<%>_8M z8GIjlp$Rd=_p#;(gY>%vTZZ*J8!{&wKhq3P#y3r|G^VK!zYxZ}L1mlM9}s`c-NR;D z(P!IBW#*M9`ylbp7|hJ8P@`scbMppp2Ls1XO3aNYVM(7>0^%JD-YUC ziP@To37)ceT@$srZ-o(xeW>NA0ER1S#=Qg}+3YA1;_ z3NRr1Vvhwl21(l5NlzgTMresVRAR1)vyOPIWtch!+}R>L*>)od3U@YDO3pTMe%9yA zz2?MG2>-|qk5hyf2ZXs5g*j;%5h{j0)C!cxPieZDBBPk%iiZdk_kNZVYV6aWQuZmu z>Lf6@*pGcZED0PA4@Z>bAXaY&Y91lNXoKw(Q|nDq{l9oy_6NVWFbY(ZblFHv#YE=+ zfU2duW($MF7C~~0kp-Q|`Y*^*Z9inOTNl{R*DB2w;q|sStrhLrUhMzD%HvVL(L2uc z%z!iuHY6Gy>~|czFyPt)0cYgg%y^-YF{I>?g-+c*;o1>*--Bkj9;YNKW&S{O9ihFZ z8vTzNLY}%H_m9%A8j+LYfzPH9R4JL*gK4-Po|lbT`RJ@-L1>0!DT2AqVtF?S0u!u|V@hF4C3(r6dC9@4V(uwogQ@%_SsGlNTC*>7=z{!@^S=Ga ze@R!6AXq>qk;T)LIfBl#9xUi9&eo-Kk4(vDrYl^7hrKRIQM4}1|5U)VnWwLu&Bs-= z*`!pIXHs-uCzqZpDAFS@db8+jc%eC0p^Hwzf@>~|L{5}WPD*OAjD(|5Xz*QvLXp`p z%ZRK^tb+U!jznPa#5EiUzNDs<(rdW3%$717#dqN5)^Kq&DMsF>jq24h?beDK6NzeB zFB|iSQh_HbYnNp>myORpY85k^oGl|;EnlK5|HxIL_Pl&bM{H)6bL=<@Vs4HuGT+s) za4s-^Enj(mHR@n9+Ur*34-bnguIS7a3v3ljTuIA)epE$GC+ul>?6t_`v*Rkl0T_WN z4+-HLvY{#*1nm0I8%n^*ps&6uS$$i#8W1ubaUiR?nyY!{s`<{U%c|p9O>G2igzoCr zu;j)IpV@St#tR@h?jaI{pb7Kxwdg6~4UckJ8;*Olk&iZN6~J|)A&E-drI9LUY^%LV zHZIB8xpiTW$si(yZo1tNJCF+Bj&*)OS5vV}S;*Af5<;umC68Wfs 
zhU7@Llp#c%QeeDIo{MBrHX%nYH-bj7>1jk$hizSx@B5~L=HM#NCQ+@z+Aa3_t>UQu z#&D0U3}jZ=Tyst{>TP8Dn}L!9kAg1ll4RwQ{vr1E(#Em5oOG4OacIj2UAE~nl#fnU zQfd?Yv=vI;7N1&JLDx#u-})k=%_FC6bf~S9u+d7Q#LkMlUHg9f@K)Ost+tw>cDGOM zY07zi5{=Z#9nK~lzFeNQh>mJPgKK&ngElE1ba@EwHY8!|F<~h;eiLOGE3ubv4Wikl zq@zW$lgR6B1{&o@&_Xlc#;n)!X0w&2g_U(UjeWa`Q|j$wCwLdHR|y{cf3WvfQEe~k z-fjYI2@u?ZYjG%2+-WIR3KS`>Z7FWWEx3DecXxMpcehg9T?=12JJ(!u?zP7_=X_^h z?#tZdA{iqYnT}scjeCr{;U>4kBu3C}^R$?^ML>qTueTXGlE0 ztK$ypQcmv*4{Q+BNIPn1)k|!A{Z`{mVH@pYM@?Y&4PVn;L-$+JZUNdJ4Ydw!lScdf z8fLHy+HUP$L!OmX4_RYIx_csZ)olS;c1~ELb`p1A6u2*KwnEIhCpd^|p0{{0zi=}^ zaVb9_m8gG%w`fbsS!tjz$%HBe=6X@;&mJuN5#L`4)=w|&uVz#(a_%Z=q>_B|xkF>{ z<0Vgb;$WxqU~l7K|NWrz^x$aWptX}}b8Vkp+}-@Gn=YF#j4UYEbLE-K8- zT$eFO(ilP07!l^u7|G8uGU9OxfpIF$acY-wTFRtx`lfNlrEz9rLaV#+v-Su!P1E@| z#3*tTA6P$)kplo>6H2_AIKzy@=l}o~`lPaD=m0qoFgi&Fos=vJ?PCSXpaAdyQyXsL4-jl0Bn}fav3;covAMZ3rcHq zxoJR>uo=9qxHV4Y02Kiz^*}@<(?l3HcLA*3xG@mEGH0`nkxFMKtT3R%XDm+U!Mbv;xw6DWaGkVr%e0&_w({_EWh)84 z+huV_29AJb)d*uH7-Kdkcoo3%gwz7)IkftPb;;jy0dPE7EU<=THUlVK)q*6ip@u-* zVV@xLf-B0*^W{|P>@uE+*)&$`52bfMSeNmLmZwPgmp7~|HyqODlw2l1YT(ts0vDzo?9j5JX-VYex9}b}% zB{}e6JD7Y&JMDVl5pgh&O1nt5eBgSzyZW$yDR5Yzb+~!ByPbSk*?hR~x_9_+SWj|9 zdb+%Ku65MvdUP$gf7^U?!g%;VatI(jc0N6Lf=2U<+U?jZ;uvX!8oY9Bbb5>)N{uOW z^7;Jcc)6s4z(E2qqlr_8dET_LCPc$`b-oeNExoojqQ7eKt2VZQiT zz{Z7Ub`)}q-KCt~WwqbsPBBb1aQS`nav$|disEYg&6T9) z$ho%1)mZkG{>YUA+U2Vkm*_%EteU4@Dd#>yC&|RsD&McPo?Sl(&R9iWz?EL`#hr+y zPCZRNjU&BG7`aI*2~E+y417EfmXN)*+lvYyz4aElEi`|UYI|#(a;wyGTOs=-{PLo7 z)j7lQI^L&B}4ApckT z;S%qe<7Bq52%`S)h>U*%8R``iXaA!_#@~PpoBt#-(lP;m5*Zy{SaiyN5E=4t+_vY} z7#}4<34Re7zkm$A*VSKdEI<7QGV)3?KuLTsAY-m#UFLW1;iYO1+#f*3q@i4Y0mVSD2npaPy&zLWn$myr|LMNm_L9F-ZU|{=MqH(wkC5GhC;eb z^OmMdFd&24`OeC0XgXzMIEuI^R zi{nLW?~035(;q~JZq=1hva9RQyBn_Z7uQV>e*_=yu0t8_!7w7@8Sx}sTQ{6DjL2YW zl=BCryhraq6}Ug|^v%MP>HEk-47(L_^f1E!v;MfQ*4CL5tHKRPNSvi5aWGh(WfYM$i+a z+9#==J31g2ve=+-Q;vEdgn3NME{smND<5*=63E&rx0#@eOWr);wnHKX-A z_wcLsV1sk&Px(zuZ|-w=E{>Y0c~_OVe$nsRgYqnA#!uWumh3ELakC*9j`4NE<>jL7 z(vsfdE4xbQf75!H)kcAAjLRc0YsTs_e3z-z)f3*ESKw;ZrX} z`e>8YAJL)o_?xkxlH)exNsGrfROo8~YiQWR_}d94f0Q1I@wZY{4afVC!;}bi<3f_- zVK7ECmys&S!ph@A#RG3vK^IAaNqAfIh9@$w&eeW(mdorx?M5;oM(|EEjKw%BRX5k8 zxFv)=l)NQLj+E{7nV)oMqY{ttu)0U*cdCQd~I})2CSmZ(aUi12Uuv^&mdX#o^%a1rW zlG}^K(B#|QiUh{x+x;pb*h10t*j#i!Sy*^A0M0txYH9rm#D_n9)~6;~1gF zmFv={v*Is}imlxj{lU~fRU~NAN5lNV6cjjfy+mx|;o*wpRL@oW(BF?o#7&Y>bIkRD ztZ2p~(_WL&imCPkBF4Yvb&}F6&GlcEjYpNeCT0AjI&d^L9$i;S!c1*5H?Vaw9@F-c zgvDERa2aDFwy%44e36ma0s@Jk3P>92F5~)tpSDTpT; zMB)yZfOl$cII?Ua>Et7XpJ_vN#Nq(a=jI%X{9tax^3d}A12*k8fMC>}q|)+@A}vOg zx16sab~;AX_G>xYT9IQjBXyNY>g zN(EXt1*%iK%6N72N+qs1g}OL!E2bCI<&1@&^&U*Yh1%QvhP#Z zI>SoFYuZsda-3Vxx@yL2K8ZSV-oaGYzFWp?xtTh0y;;zHe86i3V&NV`5jAw+Ng-{B z%Eun~un|}Y6Vi?mee6Z2p^K~y>A;^p_Tj-?)J1cHbduwq_(^K$VWmL2=rEN}0#q0E z@LC|lcc zh5_FS{)of$S@O-IfzTuVs5=_&c`Bmj7g#iU%wPFD9Uo)qi?|TMczE=ACY`3Cv^K#6 zdffDRHqVlwtQ)~(8tz4|B$}p?LJGlDp7KS$>XOli7BqtCvgnIKLrr7#6@r<%X^e|v zhb3d}M}pZl+{;pbO%r`IQo^}D<;(K8B@+W!>vs$-`m!=l)6`g-aA9uxvbqj!$<*A9 zaB&Uysc7d1Utk96-^q-xIjH|6 zGr+2UBQsF{Kr{X#Gf@6SGk%j9D1V_DzsL;KzoQu>C*Z%M8Q_?|p&6*Z*o=Rp8Q|Y+ z#&0wO6I{!5M$D87P0k8Nb;K@W0@UUu*{W z-*Cpw3IYty(5~|M7a&3YPe1}jN5Ciu7##uAkiw7#7}x+~9$@eT410j_4>0lphCTdN zQ^05k7#9JfBVY&w41|DD6L9FbFo^{mCO*t#0rOJ81QjsvD9k|$laazK7cfo5ZwKjr zp&I@NU&5>vVgGBb_+LQce*uaAAAkhTKLH85{gYpS1lE7Fby#|^|7%<4S6T<_pKYB# z(mGiG#bNqOT!Hxyhv^@X1m-^+rhh0*G5+o_{X=1j`45Nb9|}{9KifM0PU~Ra{w5{< zmexU^F}pjc{aacG{l)!p=ik#h7{A&&|48eg|LHLOOJR!k+hO{b!W4bx4~OZ$6{Z-e zfL{*NzbQ=7|4UkD5aXX|9avch{hw(amLQCOq;-CkbI%q?! 
zkaS3;z_&j4)_$D*?O9_^>?uv2tjT;u%F;cJa82ikyH#H5v!Zc;-{`P6wPM;TOjAKh z@Z8+UI^j9?r}vx&{$oMR{_!Q}+2U!H(z+pvQvI*DB_qF$N4=AbV;0MkB-E!@2p|8J zS>L!L&6lqboL`y4h(jwE^%|N+fO?TDA(Q_;rhhV<_aaXYH8@`}M`|ix?IK@gC|?<+ zrhlr~@uENnr9h=e3U#_X{-W@UOo7@=|8#Xd%0-byaDm3w@2(EYA7Pz8yE-VSg*ph* zv+V?zrG7GndU$jUz2A7JV}oo)^l6MIdb54YW0P_}y-1krGqV*)y4W}1A)Jp6!4=4| zq1iW-Pf&|3p5`yIIWX?SQtus(uCBV+Fttor=s%q1YZ}@y@4(Vnp+L3c>)P0{EQwr9 zVUFP)!r8UX!O|Rh$Hh0*#B5hypm|_HT+=Xn!)#wfxqKqZRloU(*>OTW>9ib=_t2%p z`a3a$_U33!bAR0n*YJ#$p9DEA56@Ad@QjH-EC}~ARKZdYyusCHeAjL0uuup#&D%93 z_1Z6p&Bw-Ug=^CmwH*aF#4LUal;1H@U=mYi6qv*`9;br)C9~^^fga=f@F;O@cS0R& z2z$mlv37QkU@!wr7v4LR>edxXf|1EJv?ETa147h+8nQNhp|ODh6=sK- z-q0^Xf%p`8AytLa#sI)9FRSd>tudc~nv&`)Rsm~@aQL8y^KV;cP`T~k*q0cd zT@wKImC2eMF3i+lLEZ1Br{E8nfY6v*4bCs#hoFFKpB+nhp#fqnp;5ugcU54nehi%Yn$nO(le$)-Q0% zfMI|(4m819$D^tu9LZBnpew!gPQEaw95x0KQrTuK<$N^&yFAm?L6RO<2~g^p)B0c7kk$$%_@C9 zEkBn~nQqmuJWuiHZt!fVu%Th!-JXAhhNRn2XF`L^?cHafS9pwIPFnZ-9>xqjcidRJ zAvl|xYImRUSJ5fe^+Tb)33(aCytAFmzv8 zzj*EKEb0mvbQ&x zDS@-AwzYMl^(+4F3x)WZ)iQX206zA9!MO|p60e+9Xy_pRTOSCZHGT{U+xT?X?kOq+ z=Ue_6wfJB%hnMyNS*(u1Jpq}lo&^YjX{?Ss5`oF#jun)Cfyv~~HNAley56m`_yb%) zDJRZ?UxMP~gZg{%duM|Z+k$9DWAR~0owziw=`Z+3AfLG}_z0~oOZNDyF)nKm{Erad z4G8|ZfiKlg?MF4gkb`Xq2of5`3O$wx4P*_y-u{?xAA08pxxEezNb{wF%Y&eQ33z&g z_arXNH{2c67v{X?jvfc0RPn%)grwPr`>uJA5*R|tXT!bSJ;@{?0^AXv?!nYMcNn+@ zZ#HFN@9^=WAW@Qrwu+K?3Jy^gCw}p{Q6}M0R9bVmagb<>G5?P{xcCmyrt$$seYkns z#tpO425SLmbTJwy;iR^4xYe;ShG~J${j?G2}x#SmqA) zV|~~}s!7uBF|%>l6m?0G@)67a*m3yD!XmNcKW?zr9g+oflXv^DcIJ|K+mestuy}b= zMAssd3{!Z>zmeY}V!xS75ls6An8#Y^4S{#W>UKzd35-TE!qS*aM;!Gv$5K;Ednsid2`UiQ*DtjI^uG0 z-E*Ah(fcTGa;RCBQ`)6@NO(O5lR_UB@xC57#xFY)A|0+UIj=h36$^N`0f z5(&`X-{t|<^3r$F($w;Q3g_q6qg{~IuA^YN$Uc z2}-|PpinxMPqIKY^U7ySbEudbQ0fRO=EK|{#Ve+c3%K{dI1Uwj3@A&t6|<|kqJb#v zwZ-IOged3nmE$a)NCTCFrFnA5C?~vC9Xh4Td|;)7s#a*#OATaj-c_ZxI=G0py0aBX z5C)vXtNyGGCbBSo#Zq9(2fljtwVJo4U9N_lrNSN=j1RSHRG6=6Kdy<)sqj!oCRjB_ zTJ_zv17HA>d>4@X)m`&x%@B@l*6`};S?Z{2Z86ZT!gA{poXo`x0QtT4dJOf&$JLnw z_L&I)rh@vi*7_nRb96u*Z+Re6rFsJ*Aft{CxxvXCaTHK{*HEwCfF4%geTUSB+|bp~ zn2QdmHclNiMjjuqFH>)-Tm=HcfFq~{P5G_B)qe?0!7zjAUjkF4WsFopR`k^dzyh51Xte+o?h@RuU}5}5w>mm>cmF#YW>MfxQ${p~MBh6zmn zN z;ZavtH$FZd78dsE)vK>xzjk$Xv9hwRuC6L5C?qE*fBg87oSeL{un>z?$h-s|cFSX1 z2S8?KfnD^#s?P@5&tTzaOirf%Y(M`Gu2}!GZ-tNq1$cORdHeYK`3D3B1&4%&g-1kw zgBfOH5KC7V z&Q<)s>PBNRFl!CgmaX@MQA#Jjtt;Oc_#glbZM5!JzQ7HH&rc*QFuTuAzh%zcp)QR&+|NIy z)e*gM=9mu-9az<#_cJz5|F(-y{6&`N@n9)$Ld$Y+OKXUxiXd(}wvh+td!Y zZ7Rhs`!sB?ez!)};U4!T7fny+r@h5X^M4e6=8qQQ=L}iqN2`z3XO^tr21p9LE7#|L zjbykHb^lrwUQIV-? 
zps0gGqnhOWp!h{F`p091e3A|A;H4+AJ3L0rX1D@aPI;Op& zFC-rekviY~(4d?K5;_B3Fo>H)_bVl`4WSMuL7R z_n?Z+sorXKGHq8vGFOrGzZ3eU;IX{g|P*@WpVVUuV{G;&P?O$Kzvb1e4yBSUAcE z;5aW@JK%{7GMsxL@9`eP$jDSbycY-N`-;(irGAbg$4$PM)zf)R`vac{BY+({RdgVD zoN{i7X9hwDBIG9|<28*I_dmiV={=&prb~P2yawwiFXw=s$?#>;55Ot-YBm=LS zf|QL+GIB(~umNBuengS+_;TxEcgqVTv>pMSK)qahR&ByR@1;@X;x;#OeMRPebeKNR zWZa_VF-5BFc6enFpG>x_@c!=eOM(7_1LwXxrHxL6GP zdlg#ob!GM93q~8F-?WMWM$kPyiOIK>$)A2EbK90*9ZB}277GirBwi$QoW>F*rlwJ) zdrd&;tU|sIB=d76d=2o8c?D`>a7A7uJG82S(i+M70Jh143qff~hDQMS!Z09`1c-`$ z%cWzU4EH&#x?@D)P`5Xp#0COf+#m`?FOc6SzoP%V78KMUf$GD1%=S{y>JQYJs^f3B!wCl8Jgd1NV^S*Uk2!k+>|znc z1O&W=9Wg*to*ajG8ajsnp>>ihAtrG!>FPi`BB?a*sa&ITNbBoMIw_Jf{!NWf*}ZhH zxgKGqoo}_eMAEixJGvz?UacEEV_beRuA>HMuQIyptHp(I^!$^@lZP>F$4GY$%Tj~UrknLS;`NeF>+Z>J}=uS*Rrhy32mM_|yX35unP>xk|}EDO%l68KnV zB2DE0sF0v1it<@c#X4~tVg}Puk)n})i~z5@IAREJvQ(!sifjp5;nRlm=U?yfBHJgS znN|gzKFQVCst0ioON11#Za-J5qu*RXZ?uWSF}3~2J#pWKlLvfwMgn_WcpDL-rHI-h$l%xCU-bA2 zTV^hG(h&5Zs0fJPNY^uXmY2lI;mPydi=sq9gOoTZ32xhXK!ri|{B2+|dxa_aXGf6; zx~qb(C?)n!9cPBC^K1Zqcy7nm?VI^kl|B|nj!6Tmk@a7KAQRu8E;-FzJpE9_Ydt19 z)1$Q_m%yO5Cw0MtN30qGJfpvZPx3L(>w8xp={dwx#!GZ?3M)zdxjp?a zLi`%OPOm3Vwi6%X+2eo(Pv&ETE32ZOQWnlY2;T}E1-~E`0WAR8=zB0fo45r2(56JQ z^dzCCclgxU=7$fd@OCkazPMPgx`os)dLBC<$LK?H>}+;Xicnt+HsT{V4AedHorY>$ zO<@V)_2o@y#gyKCr?|6kc6fFgjBp?;Oh{3u9l{iYC9}@eF*C*zeUu2y^TYCa5zPw% zfMX}f*|pV10_3O^A`2C%>+SmtLUyBNV3IGlTM3DQ^K6A7{X9MVFHlZ)voi1!oOJk@ zdi;69LE|>phNGt+4yPv?nejXQ>Mrt2_M#5YEi|m#d2U;Fyp#Oh0NoebjGqdoDrH!t zO4NEsHb!#s1|-@^zQ4LZJuFTxG?I8ST7J%IWCpZN^hel<=0N`4s3x6xw zGJNr%n6N?rh{hj%S+vTGc;FJ*u}sZASMvEm8Q|hPs9Y^HB>B)+{ud8qg z6gHvBg7&$;sr_J#8{=O5S<4XbvpDmd0GL; ze9JODAod`EUR{8f_cI+z2!fVGET22b*QeD70uoSuqoU1CsnTgB#_jd(3(+-#W@@C+ za)m;?xWMZRzwI@^z1YWw9Cz_2{b@1PmGjP)-; zwmy8HY#fyQc1~#2wtO}}n>uxfy2A@FyyEYNd_{S05kUPee^$UMmR;kC*J|h6nMlTK znv|iZr|NSpSV%wTa*7QW|GOpFEi{qC4N*Yn`S*s)hzHEO~ zZ)fRAm(HCuOUj`+&3)njoxwLc>=YRS-YUS&-BZ0!qq84iQ(y3+9N~|7%)BrH9ZgOJ(bbT)j+ji&#NW(1X+2Y z(+eZg?)ZV8*S!EsGGWqvp?MvNrOFYJYkFWG#TgA$Qm}+%&`7=%v9e=d%O)g7k7n?W z`VtrY*h+l(Lk5M$32qf|mBC5ikEo61N~xgGWXWOg&*GRDZ41SZC299Tu?iu6<>q1f zo5WH4q;a04N!6rjhosrKr1`p}#kr*Ao1|6zX(YF4WCZCHyy;YG>C}$twDIZm_34cB>CCt3ECd;BycsVk z)iO98Gq~e3cfAKB1@-j>urr_L z+k#et!uH*Q+Wx%uI=GJb!v6Zg!TG}B+rqMZa4%BPq*~FmW6^AUQHD`r`%WH1D;!oU z+$wMJnp*J>$KuWS;_dq4-TC7E+u}oll4IVIQ?(MsbH|d)_>$}TlH2)```eNSf>J<3 z@p62@3PCYELn%^1DY&5&b)gjft`w8744bbEm#_HgKv5}A5qv~hZ)+Lc_c8*)_+GU_ zAO(;?r@Tk5oETQop}dPHBP3%?0J9blv)`4c(v{D-m(v@Ac>{_011qTZ^5Cs17^ExV z^eROeVx^EPnWZae7s>@1DxsVp8q!1}$W;o)M7j%A?s-++Csl3G z>aMiv$It5iW`snaoGSIziLCFcwE`8CH!mk-qA((gRPfXZ za3_o?BoJ(;4)%3IA+-pDDPq`C1qn+*NMS@oqaY+QI7$F1kuZQM3~34jJyN}PpSHH4 zRTB;%U8; zZ1zyt7lE)mFrmiF^FUBKCVT)Ch!YE9ek<^p&~#ne@kmc(|I``T9{vscQ(m0!wQz4h zOr^lAJNh+3!3hj_iquXRlg=g#nnfm`R5oWKgYRbmkdnd6E0Ap6>yYf$3W+|0>LJ921L?ji8k)ej((}YxP)$I9#lsDF6)bdqR~c*txlFJ zaZN@LD|+Y72{2Ec?L{m6;5|`Nf%1rZLxI;&ls><(3cNrr@E{tf`<4tYk`J~HW|vMofy5rMMZGp4mh;Q(ekMF-ptz?Uq(%u*rK9EWpu z?mZs^)<}$_yeB&fm%LRchK!mWrNJjWgI`I97a0WHB2T{10QC<5Bv{yV`TGFSe$-J( z-ThvJ&##B|ntBTQy5+j&j+^5GrvS&}zED7)e9Yk`Xr&!6vIvyX1QCuhvxETC$TCOb}<6q!)3G{($c&gMbsZl1u}($NPQc03+*0~2z6B2#f}>iNn6 zi(}6VG(a*$b@0`18Y|+~S(#A%;2zV!00}_-Zp|BZoO%nouT?kdAy)YcCk9P9_Ac1fk<((&Y#a<#1V?;o`bzW)L{K{B{poegU89HI z>yfU7_OOML)~Q(?+Q5mg@|ADcEGDB*=_A9&L_fbCL#7N|@10r;?_=#7H122BKz-u6 zl#TU0Gga{H8xdyGr;s9bwDXm3&LBi0!ET%t#m`NPX+tlYh2*nWX!ymJGFs%>I~uX3 zN}_0{MB>ixD{&S(mUL*-(^tJcR~he5nJOFPNTNYIaPT5p+AiOhsGm|kg(Lgn0^%qQ z{;sn`@$SPjW@MM~NFGSc_Vif23i%7cHMV9;{^@S_CUM3t*mo7d{7f&2t3w?Q{%(Qs z;(0?4oM4;i;Qqjm*3!-C(nZ52&@Lrod&jav_D-1GBoMkS5Ytf|N%=J0mm7jg^_aX# 
z8o8S`+9_Dn!cl6NKK|0nOo_ZbHZ8fIS!;{bjta%pcs6qVdBr;NXC+^`JqBxN<|0?V zB3rRuKX2iWuM2(iKcH1i{Kdh}DSgC9`=CB8C3>3;SnpzC_TvxeWGcb<|3d zjSfPWo!M67uuF4tliKDMm*l?3Z~#01(&U86o1uU_-xvL5Bj5NLsx=MA z&ijvkb5E6sgHX4WAGVwzQcY~z7aR;3xMx`Uw0uqC4yQ~7GsxS?$SB%i$YUe7WF#n1 zfv#%_1Z)BYW(>V1-I6j7yd5G2eBVAeK4(~YgEoRhx_I7bJ+HDeJkzv^IK8DCgGa{au2&__thQ9*JUsVKh zr`(#e*_pr$O|}MEOVJs%5QvMIKc#tV1^vSC1y%6~VjwBMjplWR>rLK038lr=T;vsd z@$%=A(LM5g$`{unM@JV*=77*ef@bBB4QRVD@)tMcvKL@4L5AoSQK7FVA*+lsEmHnP z_rnzZJn4e8D*FXb$PP_eT3j~AA8tLXFY`;61BA)_hxqJFHp|59h#!eAp#U;&Zb}w* z+Am}}P+AZr89FoFog2O zT9TDQ{fy|fcxKx#J9RX2_|RoV@YfG1{3l+gILo${jnL;8vR%^x`T=M90%}Ir%#6wi8Yi< zlZEOHt|)xfkNb0#u>Z|qzM67_a=s3n?hznGk1;S)(KpxA+{PF2it!z&{{3beH~CUk z;tg2$UVx#2qP7=zYufmHC z2=~5z4)a-D$(9F77x;!D+n&9aTO!dZq8Fcfb$v-(L3(}^ArgB2)Mndi;4lzZkom*M z)K@a5QB{i!=$(44ACp`0mVyACt5ICyCpHge%jN%pBchC{p_+p7sciD>F9 zbwv_Rc#pbkzW``iQ8ak~HGM@2&8ByDrsPCxs;53Ock8PNQ6-1x^yyzqp9}DE`#w)- zl*rw%@-|4KLy7Za?*EbzB8X&|8S#!D!}Xh10G-(9&CiC}$*%cpBB>!#MtRxE0Y>?8 zX8A@1S@cGv1;xX=bcL160mjAkhyBJSNta0bapm)u#--f>-1}wy%mXGB!-55QxufsY z6N(b0q)n?!RSQgOmR*s}YIATI%<6ETOB2>@787FB?>8g+m^YpdJB2r#FE@Z|YUTpX zTiUnx%v<5m-m)~@;}RXWmE;Aov?FnWEjuw~`FA_8TaYbGakU%Gx{2JtR=pI~y~n-B zW{Cv-)W!v;{md<3>p{dSM)nGl318D8;u5)-VZlf6*U^qq#+N0p-4w0HMoEgkPP`K$ zela0e<5*$xUN*^SO2yoM$Yy4V`k-ozM|;6~wqMKCcHXEYm~P%gA)mvRrGhVEJs);bxj&!QPANT~UGQzxm@3D2y}18kEASWwW1j9qh9UTz^5km9YmYxzMMbOP@N>) zi6T%uUNfe*(bC4i)608HzgH&31FuTS9PhGl-A;diZz8o`neY%0Pl2K@xYluUO7NPsyj;kvoE( zVnTt!j8F6t&p>am_A$Uup2%Z~B8f5u-b+(F%L_?p6k}mTiqLt0my(!-mreqASv~PY zwg5e23sv-`V+-pS1AXT{1588m$i;cOSz`@|s96WT4C;TjfP|6x0$`ZoXQ{KeiJ;HX z?|nMwKOJ)5@{1r!t;$%qvM2r^%Hiok?0HE)5!9!#^>krp*z}ox=y@Rxi7%wr`OU6K z5cX?ZbVPs!OU^g6`VCKILvkApOn`^(J7gLlh~5-kefP29us$8W-b1RyQ{|G#sb;ozGljo z4aK8LOzCr~(5?``T;YX$d)e$3Q7{iux zYhs*XwcC&I{8XgUS&oWhLxry_pIuQ+~_+%VRGr}JJ|x%anAP2T!A7ZjFX`0!}F@Tup$%v zvjLAEwqHa9^3C0h;pX91@mrag-!mAaBSP%z+Qf=!{00`kBjEvra`nFk7EBKpO%^RS zp;%b}m+{s^haRWptMtd3CJ!BR=XK+i+x`ztr(!*>>ut*y z(=*LyO(oDXKCR?Kqna@ZBT9Y-(C8aTdqEV5vFU&;LJM{W&r~Td&pjC$JbY&^F_!vS zhmMUFUM+sbYJeMqo_y_fmw2$Hp_BCbI=v0L?O>Vw-mLiIeObfn;XscMwjbL~KLiaY zb$`{D{(i>KV<8`H6{#z|LH|mxUTv7pEXE`Tf&LI8I95s9<|k$a$8s#9uk|KB(dDfP9Q&Vf6w4&7zjNXp9^kaKVURQ6o54sae>|)2 zt(tz=&#-5V!_#w2ky^5G67w3No^5TCoAv17 zux#zC9RN*j*In~zOO3{_?SJ~{;G)y*<36drbAfGhPzAi^csB`e_}JsJtv=z;ZCgp+ zH)JDwZD$R;PBHa~GTNrRdYVHt!nI*_#n-b6ohNu1F8RXRLFMr92Y7Z|o9vIqtH}Ci zGeN0KRdu!%k%INkct1;$+%`AZxi)S1MW$Lc)qwt zQ%H_W;3~Tg77}R{cR4?;sb9?rxpgTZS(qsiZm)YpgJJ<)G1#G5=}zkg#*GH9WZjlR zG0rvBP{^ev*#|ozF{{xUr|}w%4PSSB6}Kz{(*iq-Szm)AJSdF{_~|SQy$Y*67e>Sf zhbBBbR){<1wfPR0-hqK3wTdyn9}j7-^I*1ncbEIMom-B930sVQ(FdT2En%{+2Mw%7 zTcKrZ#ZW5_Mo=+osc{|q;90HY%I&9stzt5A1_gn z_mHRcPONcJ=~Wl>GvlK5X}q+43GoXSc0M$4x9jnFUF$;}s}TZmw4?<~#p?Uc@Rs`e z#z9b^w!{n_@F^0i2D!Y5Aiw5KZ&W{5BcC4EuOECqS^2WFdyP-|ML@ugcwXYMjvQ1T z>>qqNa{OCzITfzyoFOi*_QqZ+ettR*V3mO8nt+cg=Kg-pVSb+iX#;C>-Lxcv{e=Br zbbGz)HHrlVt#7|<-=^zS3ADm<7}^fXi48LGwT#d4^!glhgX^|%ZClo>XSQu;;pb&F z>$dnN9vg5n`{sxnxeyXutP)z25z4|D_?R19x$S;U z3tGC?i}v$-;vdnL8V0P0h)MS*0(xTXn4y8}EX5<)dV|pH{fhkbzy9@P2Y3{gt(nfA^Icw+1=;v|o#+p+Y z1{(E7|GW+tM+EC-eS0nu6XhRsWfd2m9`j5$NLL~p9TZDA?X6`PuEP`KXXh@a8cc!c zZc|tnYZ~iiuIk@ZYvZFDha3}k_8q|6N})#?K+zN9F8P(~#=Cq9nk^AOYHwG4$aP;L&Dp*gvF{jbx2}5PrM&)yx~n8pF_AnUt(!)G!b1= zeqH=2k7=P~#LPwT<4n>c-A$4iPIB;ZgvxXb`cBA3oW_o1Y{qQz#$^H_FlpvG_6vT* zd0glb{ZG3BxijIs>_vW*Eu;dv9=<037LJbB|%Tm8dN3>3_+ zz;G`!MOHIhZgYfhbGZ90o#}IZ`ZL_)bDgDft!i?(9dpC#b0c`ZY|;o|%A z5_a>Fkn)r3^WI5?KFafku|V??!*g&eQv*-4T-Ou9W%;U-c>MYWH30>6js=zZ1x>pJ zUoOpC=nET=GWzKF~z0?t)K+9-A-iDm(+9h=c8kV0U1lnm|NjUkjqIK%E_h66U9ne 
z-jp^*mjmkaPazqFtQlkD`88uD7zri!0Ts7u6$E@>4yO{{gbL1v3T1APAaW&FLkZtN zMLSt}>PF=|S>&qZ4LHDR*==5F{A`h*Di}L~9ygFycNc>Z07&4d{)nt=U|g+N06xvH zG~KJFVE~)lRa>i9bdc424XoMcH9720F2$;HBdoG>tQuD#EZ&=qn|GwwVqk;}a42op zh9cL66V}np)wrrxLkVk%Fe^L>B~+11x5nWqw0l?0S8fT|+*yNHSkCK^n+w9Mr}@Q0By{a~GIB5SL%j*oe$Wsg^#u8dgUL z_9koelLNJ+VlS{_l1ZsfCfGMDgk?Hef8T4=@@>@r*3_ZXgd7e4S~Lx&*4D!}Z%a2j zT#D3h!D5z>EN*l2O80;F8aS5#Cbt7WgCX8(yf_ zU){C-h^$J=6YvHAQMup?$l83?0B)RUdf#!+ECAxpfM>#u4S=rqZ@XmKkl*SdBeTG} zvSPWMKy+sCpYo${@>B3K3L(FZ<|jhA$9ke-l%>JjV-uA1?j@dOW0q}WkJWw;8ZFTA zZIA0>j|;fhW3b0lqnG5hxQHT9LB1L;J7Gi!h3SOZtM$IF1U6`+(eq#!rrx5snQM!ODA!i%o%pz9j)C3;|%YoCXfS zBF2`9*_ImFheW*xj3Y;kBhbWbl^fw#f*0D zzXLbDn+zD73g{ac9NbCtzFF|?Bx`*mg4WhbvwRE^q-Z5?!eKuddjEc4Onx{|1LaE) zBP8kd$nI$W%ZVPviMCGH>CB0#&WTB?dI8$t3|in?r;^xDnN=!{r5{SmUzLUk@x4^6%8O|!Kh#SRu|9(s$u$2pAIj_B74k9sVdCX7T;G?WFL zVX+*4#{;T{V$Z#Wk+*zN89+$Sax5j4taCylkSurz#IogQT!ZaJQQe$Zc%^nciT>!-E?e&0HH zaA_${p)D>2DmbN-;#O#JE2X$Yakm7L;10pvT>}IU?(Xj1;_~vGGw(g;&fNcC|Frka zXRq&C?f-IsLgL-%D@@0kh!=qkO@~^_URuiYS{Yc{NjYULgoLE~T5X5ryK7iC6nGf0arrVuq4AfWp|P)xrG{ z23<6KqIldo*t{4giCyP>0>><`CWVslPh*b;iy|mg{o`}ypIF|%P^^(7i>1o!^`SEk zYp#ntALA+LW2nStEkDL9Q0rw|Bqeh#^`D1aH5c>1!!Fc;b=LTQm>2=1c+*g#Y{X!# z$ecZ_y_|5~s(E@6VCGkA$R76TV0>9w*t#_sXk@uEfV3Y{NU@gA3nBD*rY~%roBfY! zmf9u`M&){I>be3iextwkFxyDIh!1SXZxzR-M1M`93ba}3^C#IZCK&o-4X!L7#akU? zjG8wkYWzL83ff=s77QkBNv+6SZh9PDwRMi0O^6#e&(I^y@xkTTxPqZyXXDZD$`8*V_K5kW*S+mO+hvRfv#bW1qLvqF(L3VeeAMv3E1 zkD}?1w(W{jZ>}Hz%LL59e$~<&+N0yobt9fn+HfCb}KZl-ZjVvLsV=pFhsF>7QZ=MZ3$J zdlu~&rv?ou9ZFN5^DW0kN)Me(@+TV7ovzg$P8u$8TGMc4oB{wOX`aXQ%l_PN33;h@ zHt((Ol~U2twlUHqu?!L>U?(LFo3JHaeKDU{B~ZK=4W#+}v1NT0eU*FLn#L&QwP6nc z&ifMG?|7nWXN==)Z8xYr?2_qi-dmx~QR_TW!I6LH=O5Ct2agW$)Ba6tE^;uhX&78o zPMsq$=Pn}A&`xCZR7o6{8C!{VIF1JHN-u}P(!YNqGbV<}eL4(QJ`XcLhYW9y&?n#J znJ}GP@mT(gTyB!_Y4jZCFf{$!h<-N8ch!*^mt_D$5r?Y|zsU{%R2Y8o=kaX#Vdfd$ zt*$K`JPiBu9bNDk(%5PmeQZ1y{vnn4U;5*5X4=EEhDW&8^LnmJsCj;nCVI?BMwC#%X8C zzZ>RD$R~30i%ASkzaV~vIU;slRCz`Rh~%|giVW{}3Hf_=_2+itzTM#fZtv%^Z0ka} zdpwDfk`Ai<5hcaFfx(rd1Ifc;nF_dPZj=3A@Y^_skvD)|*;%uR1+keE4i5*wQ(}Z%pzb~k!k>EqJ zFs;Pw#*D1>(0^z-5OjOXQM5(GEIwsgH`c z9|Go6Gk2E;oA_-vO%SrZ8 z=hCk+-zfM|h20-pe|ved-k;WLF2KKhb&j8V+$;8|H2+~Eo_7RhGQ;#1i;~izS94n7 z{YOF8aTzS<4=a)@2|da|i(h}Fy1k6jU<5pVxAS+~_sUuE;jG2$oR&@sb(eC@nZg81X+IDRg~Y1H{s7{?Z=;)BXW4xWDFE+C7(q(As_ z(rLZChx_*6mz%!nMpttxcs~9_sbTA6YVmsPIb+yK&WJmTb5fFpIJeq|qaFP-RuN9s zfF70d?Ztx2GA=ug>+;Kt)zGlJ5x)W*+@2kl+J+OjsXOP6zWxg@qk+7%PY;(YnIBAa zzdWzK@03QOj}19vf5$}M4iwg^CbXdnFl6>+Dqv?KGXdHJ-8d6q9a*`##-?W4MWMQ9qtVl=o-8Ajf5S8--437SgZ7%+@Q{ z_*En>2ExoH7;na~%_$ftuPw#ary54aMf)s}U8A>8z4FZeyf0mo_L@ShJJO&Qu}SZ^ zo2-1s_X)|i#sHF4C^obWot4Pui@*LLr9GPXmG9-acSxF^sw4jtf;n&}Us01x-uv~q z9v2_a8^&ETbO7a;;JCRGW2~bL)RjeORd(2IE*gFnWBU59*0ASNRoaa)rr>ge(x0iS zbO|FC5vs^M+D())2`pcdMM&PB?>zG<3ike4J_h!Vxb=&FQv5#Aiq>M5*=wKK?hcsy zB}`&o%_Qe5nD*s$FrDWJ4DJ?vEF7N?a-u20xF-z@-c!k>{vgdSXMs&kI`LBQgN(}W z&#DxSU_ZH<3{AR1^_P8_CpAMvE? 
zWu1U5&6U?RpB(PG*6}Z2Blug&PmgjmAqJGYb*+G48!uD-m8%7Qr1PKG{)0`6Si%Bj zSYdjB%Z@8JO}n||n`J?p=gUuq{mN%Hx;Rn`v&7LwZCeudtlAarjnxTVC3Y2vYphD9MY8C=m3Q)SWLnxi249w!$W>hFhO{?O8Ni z1o^c2OsCVSBkaN|8}b50Mt&Y1_mHL7s<5TRGJ+;#2f(oeZ;{6} zlIPjHD4MpPUp8Xa7n^}E?jJN&)G3l*tStnz*R{hMh7_C@Jp;pcjK<1t&w+7es% zx(xrT@%S+3Ld3zY#cP47Rcc4O0TJBa{2o)wrGM%ZMPaKoF>~Z0Qn*Sg{(tI{7mvdD z5!?4i*6b3Cf8#EjE=!L8%-=4R?6|bx3>`na^Ke0-vTL*!>56vC%b_`%`=KyXenv z?FOT7hw)9DCyH1vW}dY9>+bt5&os6LHoLfLE`en@9?p3l#IDXJmmbg+?{-r!6%5}Z z2A>OQpPM?*`vAv()BJ8dK2v>p!qwX|kjIHS4)B>jUbi4e83Y zV#NbWYkZs019KqGmd@?~DB0mXl?iZ4FSM2bWM7#f_c@^O=M2;T8v5Qqb*3<2on0G6l4 zHCIah8AcxxBWMM{)eAsFMxm2L14sZ*yzI7X6rFuQE=fcgZv?O*0GvGRjz0*@)e3trt5H>THgrAs}tH zP{Y~4gzG>iO|(fz{|D0`Q>`ST=r~vdE>uQ@DxJVi7RZSku7e5SV0&V2f!kL}0Av&| zZcL^L@Err7QVTD57W(e-UC88$$LkQgP`6-9R|hpJ&)q3cTHMHFbB4kls#&&3#nZ&C zZ|HNHvAKe&^m!23v6L9sz{r)b&Z>AX^QUKp09KN=U6cQ~0nj4SZrs8?ti)ft0{eetdOHjDnlQaPC`;r zdD2O5JQxW+E+F8%P7N$inoxn*v!%1hWp$dT7H7tU?fT*iL7@#PuVga8Jc(Uj;PvF8 zHGiRSD%kU0H$5#Fm_Z?5BaTr!X9<)Q;*vhm7RR!e`8G6Dr#c=-JmJVJCKM3!aWBnW z2oIAzqckf=$Ll6bX;)05FVSv~;Lz0OYl7hvZjN&!TwW+LX*UmK=Z=k`J@+S8y)oa| zFG)NUmsuOg9tu2Viw2a(i)sN&uEDQs!_3@)X%=`SNDPt_3mg9wdkg3vJZG2Cyw0oq zgPD-8h4}`OMWFm_^1kcOm4Z`6gX7zvV8 zOe+mUlEOz_czTt#1kUydC>$Q~T1KS{3qd~%-NSs->3U>{t7i4Qn(1*8_1%r@p? zXJ;ul6vBEcvVpNQmZii2rBxOwU7#Z0q#}Pg7=k)3Z!NT=u#As`ceOA$%{Q7mF{-*Z z3Ljv>DVXb|RgP{5*f%T3SplGtq|~}qq+3)3h2;4ACuFUcl(OfI%w~O@PhTq(f79>m zDMw0^M5TjUcFtZl7ZR$vmys=4W!2Af{eh-DB@PW6KexIey00AYPoNUyTSZ({EzOZ- zF;`7VgJ-R6ReMwR{g$B&uP6vmZdX+Jil&HAs}c*Z_OLJGc&^f?Fe9F$w)mH*9y8cl zxbe)MMyQ_P50|PEwbxidlMQY9TDe!VXOVbDUAr@q{ViW;ydG$dU-eS1 zv3sBXwRV$6V`J4=V|5Y2_O@QTKWjX>X4AQeY}F;owSIP_p*guhX|5;<-`_;1>;{lZ z6QZ!-fjAHjP`L$;_IoXaB97%*RFmI^$UjrxPhYoaE-y+aW{1Z$wd{Xr1!Yto;UWI? zRIYOP7oyOEK^x+S@l)|V`SP&ef(7lfczj!>pMKq(CJf(xgz70^^n0ur7TH)EIP?K{~$o*|wA0@~msmV*I<+pt8 z(=5&{LNp=K;R_@I-*^ily0jw^k@y@qBcz|8te>rzvd}9lY1$tx+jSw$q<$T0@<^@F z$`lAP0l%Y-#!H>^xu3v0_JJ;6L*Ij%iza9t zHN>Rf`20Gi_j7qKTXk>3ovUkLZyKjW%0X|ERbPf~NUm;Qc1mB_LT{nMuVUK%&rbc~ zlKr-rZJ4W)lPv^akq43=Q}X?E2N>Eqte`CNMxuLrOq@l}*$Y5Du@Py~!D1rHm z$e5NTvu4j#j9`QR{TKyhZ;FVE~SXI1te zsyhe6tlhV(eWRUL|yO$E*q0-sL*jp}THqmhoQT313K%<=TgFqGF z4<%7OAmq=4RBYFp>}ZPuq4$f!4gs?f!qM->KPG=+sjI^$je;h%7bm~x0KXqj{(yHb zAv&pQ?S1?woPLifE>6`UNh$SkjTR@rHme2Dvh)|LS=CH?!qt~Ovtx_3%}b`_4tx9# zr;c>%&D^G8*`v?k(}MvsHSDu7gb~ntg?Ku}u)`rcXe+3s1FkoldFap)JIm{U%DN{^ zwk9pI9*BU?brsH#y`Co{oPYMWN84(i(r>%+DAk8E(!WZC;37Tv!CpO?iR(VBGV~Vx9MkYjnc*h2;cOi_dizWxz{^MN8}V zOB#WTD+J4hgG(!1lWU@je+eOG&xT1%=6Euff9x%>nJgQc8e2>KdTa6vPsvPteMPWy zMT~h>A(sQFv?>|ADiZ8KxU|Z&vC25HLUW|0Q~O&h_&2NfZ;cnfGc0Fr@0ZS+d3Yb* zT{V;Bf0Xl#Ter$x7bxWrepnYE+VHPk7o*>h65EjWUiVku@K-ncda?0Ai57im(PG+u z0N3UfbQ7>`6&PTS`O`cRd&{8IyykulCAv-lALh5()Zn(!I@;19+74l`a$=;naz3(j zAzH~5UpcT?f%R_<5bXF3ZQ6#Z5)+P&b`P06wk((tP3k>gYJR;_T)1g%v$H;>XspN7 zK-l?%8x1qh4)4-+^kg@?#cq**k3eJZ#(58y!nTyiwrt3jX~c`k$iCKlpVXxc1F^Hb zxU;ynkGbelB0lphNU4|`t&{s(cbG#*>H$g30mmiEaj9muf4=R3#2yqnYg0Q<)U|lj zSybP1!Csu%zYi_l#>6~`quH!q!yD39OO50X+C0F1b$rbIuAjWMo;7izrJ0$=$C??Y7S6pN{`uP$@966?^4A>LNXhv8 zJsJ}WUFe8V7$Z9V2UhM;O*wZ?`*(ghqNcqS@c4K={^N?%XdzTH_|;4PJf92J;nv8v zNH*H55S-gTt&A%=|JFGEZQys|era)o%Ov(`ewE?Z z)n(Uou9wNmk8-H%{INCP<@^5-w=$Y$uHc0HdeqF$6_AQeP*%&_ErgA=JaKmS8kZWc zfsMh!eXijZqhDxYUs6G1XAhkTRV4L&U=0xy}>6Y9BO0ck>T?t5NaNa$I6 z!MvL1JO$wQgf^9n9Hcg-gULF*c#%$;fNV)J}t;5U5kHkFbPslAk^&6SJ z^!A`QGCt4zFq%%AV{RB^X4mMISHk;O?y;O3Z^J>7NzXspk8p*s@#g-OBBt$?;Q3{e zEb9m#v=zXeTre-$X7CWNfR|76jrL_LBnx@reJGqLz@k%arXFNh_oGeQYoCZU;x#Al z;(P7Fm#v~`KF>@j{EwqE?dxL2--0Teni{~)i_Lus*5#L{f3?4`lOIiWL1~1&2?QP9 zltgg|C#@O=|Mm=Bjw`{udY3ce!g9(q^@3`7!aSOLaqil11#W3NPoOu%^I@UU1;Q*@ 
z?SgozW?(B8CerYDwmHE2&4x(yrnEw7Q#+Ik<<4~Ar2+rJ;`fi|IJ7hQuP!yai+}B} zsUgb*7u)N`>$&7S8vBWVW4E`D#D>0)A}oe}XlFen{?LfHQQH94dZ5h~TGd=&Z+)$V zyS|{AipO6>yp9*GGlh}oYxkKjub^FZe{U5kQ$r}BdiwJZK9du|TF19}=13McUaEz@m&jbJb6M{m;)8&G~)C+C`J%-hjk?p9q=?f8sWuiQ2T_1QzSMVWvU#P}iC9_CNL zt>-Q;wk`*>7Ql+(OVA#OkF0MS(GK8ALL;%@$w8+XK=Vg1KnZ%VUxU%xuYkzXwyQaR z9()20Xa$Y5tS`?HI>4G|tXsUQk%Z$t(<@i)@5br)Q~lfE>yL0U==ZOK%57V~JEnLL zSrEaA0j%_DM^x2K(SE^UcyNoCC67R%;0wk&AD|ffoE<<3{8CiXwo?3#=fmCJyOA~& zfp2juCkK>h81Gp}b-OMBptjA7Q(<`-8{-TZh+%k?rldpo4-tcsr)xIrc@UaXylD8- zcv~q(ilEEoq68U%9PN;h)Mjv2@a+1DxAHil*WbsnxAC8`xs8$j__O@Pz86bXD$NWt zK4td|8>Y%dSM!!>4`thRtq&Kw;V)aSkMSM6?hZA|TCYzH%398^k3CjzCm-vjKV2Lu z`<&nY_&3MTR!%7CAC!>9#0jJp<=u7G?x*v6fnEnX9eUXj9dhC;oxDysV8x+;i68Tn zWs|fB952csv;}w2zCgJ&JHneta^vr##+gI% zl}beQf|TA$hdz6WnN3s9F=?WBi_lGmY3|<6IqE&;M>;fNc?(Ap`w2Z_^J`%zv0n)< z*|+$@uX|LwpON}u;+=mtAN1UFOuJJUr+Pe92F#4TuOCojato1EsLhV}P5qotEuT@u zD9bf#<8}M<({_{Ya`aBWT%lXs$FYb|r)=FjvrjF!Zx~dgXcq^1xdnmJ21I??kJQ`8 z>OV1`_Vc`_{and~IxzJi=1F)jp%m^qkcs{yUx;Z>B|v*3%jS$+?vI69#?AL!xyjr& zh3x8OeG@rSKY%hhd7|G{(5eeRlvwE%>C6kEyhmfQol1)IjvI#>_|CF0VkHe;+b0oH zpGMK&Gzv_ zY*~eoc-Rt^)uJ z`eKZF6SYuzYE6kP%yI)IG~LMjPtk2LR}aHurjz?)__380YC?OWw|2BLnefm|-UBY1 z`mxAsP}W%Imc4zcwyu@%&?eya>*MCf+MdA#(=3n0gWB5qc`su{b(*JKDSz`Z%$86YjUsb;?Gl_ngy#dWp6B6+@rK#Y(GCOeV1nWb^*WjB>APcl8wQ?<#$iiux(3k;>&+4IG8so@h#_NFFiI{ zQg0%VjVBG^_a_vzchYrY;x%tw&mlhdUoO3mdXM!l-r`=LJ(fLotbTI(x1RW5KHYl% z_~^Mq;d<4h3bT)$3$uUcq z&=^_$-vD)(1^kQb{Mi%ynNa?$z<{@`0o*bH9GU@KW&wiW0Ab1i{+R%ws{j#!0C9mp ziJgG2l!3C~KzY|dMVUY)vp{9mATCzj##mjF1o?(o|Mp2Wu^=^rouDUk=PzKeF>CN| zaIk59un8*Id?wg(C)geca=Z%u!wPbi0lDOZ+)yCb9gxQr$SWcEM}835j2{SS`<=-c zf&#;Kz~R7M_+6*=ehRx^07E!R}9oQ-`Y@PL2uMA{c)3GN%Y#;Sx?hHEyhM%&AzYV7K zneyZd4!gc`{EIS;nF@d02}i??z+j8Ol8wMIkH8Iyz%Pg(?2RDajUdI1e8v_@E*nW< z9!V7vd7mFHixxqL8^y>L^;$NH$vlcBB#NydilaA*Yd7jGZZt1jH1|8%=nv-6{2|eT z1<}I2(W1N2Pa)LLY%yPCW4@ZlNQcD87R1Q+#whN_DC5R{<7A6fla2jh9;@+GMJ#IBJl-WF-mM_sqc`4b zH{J&~!H+E=KsF)BJOLDv5W-xL0O?JD?IwidCPuO)M$0C~nkPP`P!kIhlY0~4yNPMI zNf~TOf3jqga?F$RLXrv!l8SnhN_LaVaFZ+8lB;BsYZ4PXM}poYBpc=@7mXxy)hGYF zO77@Q?qW;nrb_AEO-41O^z5ee;ie4Y!aJ{1#?0ZPA@B)7_&_3jp#VO84PVrPuQkBe zsZu-5lXrKYPV7?;cTKs9*I!Obtr$0MML%T`C z6iUakfZ!#i6S}7p+@upVrlYfG0I4&ug)%6TGA6S!UNmNq7pBwWWs(VHQfp_voXcRu z%b?lIc$Jj-<|dPwI*YY0lU*+3-CpMVB(%}4G|{B+G7{6z?Ac%BvcFnnONVAZrBLPj zvK9BTmGN@Eg=QaZXZ>LJnU641EzHsB%hB7*F~H0H$(}1|0amxjO^IhDGlZ;k>0%k$XF^TNybVbAxI%MY-~4+_l(73POf=i$WVn&K6J zf!XO&1+f+d@u3BYg$1t3eFgBnf;7Cs4EDk-xx$<|bE=PM{p1Bja(?1_=zzcCD^@QKGnDv)g~0x?&UVl6(QJ*kfBAbjYVB@#XZ!;-ATnLq2m6%;z7KUmYd=+>XLT3 zl4*;Q*}~$<#*&%2lDWR(CF;@@p^^>ll1e#AE=A(8_NKeOZ z9u;`~73ee-t%d*bD3CXmGc;Aany+BgDd$M8yx9{!-9z8&^F6ev%E`;UxiMeCDB~6{ zJJl8~>oYGatPcYQYcU*^OvM*-i zE_$+OE()m+)~OLmt_K%YL+9&5J&GddtLnpb%3?UG0wE2F_zf|Z4axlt@%s%Bj>gE_ zCu3)0LULo4a6`U)V|r6#?rmc_$CJ^ah~}!UfyTZ}yxwrWK1seQm&F2sY#D2>qr1iR}^|rtDe!q3I-+WHK6+^+3?4%V3io`8O;twFNIm%95 z<&mVGZWr@Na;r89Xd4xxwIZ~M?4WJ=qLiAm{k1|nlU2J!G8)}o8waQLbIP~Y#GY`H-kcO$8`{2t;=Z<9@d&G@`PM2G(O0F=AC}yEhTq>%te>yj z->lH@3F&X-91yC-Yq1)rxf7Qc8R#dFM|chl>-N^}^bIHs<}>$?KnMLr`o?t!7dg#_ zQU+IPd!C3ejU}t0*MA1riib=ViWWSFjy->^AcoEdy8i7Aoj`}1bO-(p4BLA4UlWW> zDfHeejQsW-MuUww#|~i)j)>h2;}VXlG7l3fjtY5>kitgqsz=BNN8jF!e5WED+y5|1 zr#SY;bCeM_Hi#Hz8XV)f8)YLLclj{Jr8rLLImQbc2Um}N7#t7t92X><`1yWZRB__| zcKkDJLa=)L>)-^D=Y%ZbZgut~?Ni64WLJGWz6gj10BlX{9%>mHLo zVN*(9x_=K&9o$Zu6HXhxpR!h*PVktrfBFDcPW~C3ZuOXUBb>?oIO3%^({(%T2b;mJ zoDNc)v6Y9FV0mBQsxuR>ow0+m&^;lpQ&q}t(0z_Z&Q>?Je-e)&$q)CMo#Ct1}Bj4 zg<-hjK=W)5;S+I7)z7skdblt(xY#1SFgrLsvADPnGxj4`GHzbnC|N4}vlPU&B(qPr zOp}>=s5rArw|q&5SF5=E{(kA|aM_V>`BrcG;(i%R%mZD2`rdj4Utettw(?bRg@Ah% 
zn!ZfDG>Rv-N_8aJY`v;WxB8WuXbtCTk#1>%l75XP)e}u=jecpCwRDxdWsSgkjaz9w zG;fV}Y2u~+IswtTz|!!$mi5mMz?sAKyW(|;hhE3&zDSL$vv$;!q@ppUW-m}GG`w)@6+SHAt z)crl(eZ<2|{loqo?Lm9kOr`z-#Pi@2ifE>@<-irOKSZ?LM}O$Qa4q@%ka9Se zIyJR)=z4d!LO-=2c66Y7w3Rx!mU?81I69!8+<7?C6FK~JX0v#te{4c~d=)nFx8+#P z^Z4Om;fDT1Z{Zj-Z5-fpq9SsF_jUn$`9v+{goJpU;M1wp-N}oxd9t)q72VTx+LvRL zkEf!D(^t!L^af{Av}dg0qiBk_C;iXIXWmYoUtJeS`5`Z%?T2du0EAp=s-KBCXuEWcJ zZI8tlvexG(a-5pPf0L2_M4NDvM$3qvSG6X;#~wjDyWx;Qq9vCDWDKuj3@$`HE^t#W zV8p{##~0CW0incY$%cS9xdsNqTU@*PxoAE#yPK-;tM=BbQiGN7$HK_R3s}=lnfrCp za#`{`TbeTBr0>1#ay?qad#AS;GAfND4!pP|=&u|B$j3t76#%>vzz|Wt@`-ERhkNT7 zeW&cF|(7z4<|%>!gbM+eeTadVSkgV2R|p)oPMSoje{B(ZQ@#zb^{ zG+Z7K8YC7sAt?o(i=Gev4}M!iM#cd^G`F-O+uA!iySjT&z5j*Z4v&nEJ&sRIPEF6u z{=dYxGaLOWhd3=P5?zL&SHOWx0iBPLOJ;;^3xy>@ zbovXk{EK`zHdu8|x~%XA1(n00L(l)=Z~rgtt?sI<%%2)=F6=9twy3YY2G{@q+M}ZD zB?)Gm#yevnwMj!^bZ_~3X!;C^N{mHUH= z3>FxKPg2D({rH^85F1FwLwiMY`t6J9YNS4$1T_W#noJ!zqB!zR+LSUwdGtg6Hw*wp z+TD*i`RP(OIqlU#gW;Cu|Hf~XTU+jLucrG%hhOeD%P?ZHuK12{kZ8(_Pp|hf-`ex3mw2f>wv~Ytvay^k_eH>42dN%q*UWBFOS!c zcz!~sPWf449GsYL7b)}NdmITS$Pk6{VFYZwiIV&AiWXan2=ljCfaL75o7QgBa=c+M z+g5^cocUIwS#H5r(je-`Qk?o~(0Ga`+WT!u&7m_ez;~7@ib%8VjX!|$dnHQ{grJ)C zC)FESJ79VYEp>ioCeNOkg95poOxpWY|8;5bPx5vskF}i)dI&RxkkMvsfTk$yf}$o=&8vRT>fe_yJJ;_=9bVvqbumpVB6dn9{iBL9|&&gsPgwM<;-vOtE`d#~!D*?fTfeo## z@Bx7@W>m^4s%PkMsIS>!2ls)w4ompJNEN)9vDP#n`li_ln;%CQNI(*J;yAAU7+U0j zPMOk{|1+?WwV)?7eHfu1;Pz^g$psNBT4%2RMRv{iubE66rN6BLbRQNhMjHUI5EqqefCuc zY>bA{5JPmP$%~B%`1V!J;2W7hP2o=_FHfWAuaA3b40?Sd!UER-J@39!Z~Z1keV*h* zVPgtlc4Z0`8X03aYyz*+^<6V#r@)eq!wyY%+!w|vpA&W5id^Cj9CNbTP@ z1{DSyLGa%cNC0M@6+#os)n=6UHVZx{G@Hhdg+RX2f zSkycUbFF3KmgkPA%xnqag?lwt%Ik$xZdSm!*&e9)&iun>0eIuLAk?9lX+t>82 z0*7#7i-ombDJQ8~tO;fYm9e%}G6ph?YB_K8l^@yc*J4=Ad{LX`FG=XMs?-IUqVYE% z&*!On*!tuoDTKeK@ujn&t&qG@L63`$=m_XI#NfqX%Eu>|8~Ok9Ol|Ka{GBV@r_>nV z-=lYHV=|+8kbQkJOrT>|V`6jNDB1L+1eDNe`}t4CBv}zt!Td~3`Hyz+WZx~j^j+jHx!1k}4@}QRe9;2T%n>C| z_-#3?#L5&B2EtYgV;SI@XWxhyG#yqB&Ylp+W@Ht4_Mu?1Pkkkf?_Vo{Q`Q@9r%a*HH zYmI>GA3}JN#pgG=Hm3D0M1lKk*yT#1Pp^5LYU8KV!M;v`=`Wa8R>jSu35N31)1Hix zE_8M$pY3Ec|J;Q6Mr{_bQ$|cFeg}eaM?SNfh#MVVJ!4oU!((&0if?jufR+rFMSv7G zUeQp&YJVSsEjE{NwWTsaO`(QiQ+cTw*1fE>d7pmry;^y-6-cSq8RmOMQQdpc1ppdy z73+p9hiaJuDhS_VXGMHh_}e{oi~qsp8I~ftj$Mm1QT2a%OcD^Q*1O+)l%&$A9xp2+ z@-fEGwbCY{NO93mi+GGty8-Jg z#bz_ceM!ZRnfw=dSvTt0RiC^@a==#H`s^ZQu0 zoaYMP@2Y5M^|lh8&}`cebG~hy+Qt4W;+V3c5Z*Y`qjXW3RrK zaf*UDNW55FH>n?PTDly%yi!=)(5tKKYvsCZ?+`iJV0NFm9lGikemLC~zb!0jvGJ6>b!3F{)iubX`|sW zCa-sYrlEWDOyqL<7Up@}^mu<_P6=c z?706*($4rtYk0n>cdF6wf3F$vaYpLg)S)ZSpL-_2ciOCJ%Yo4~KvFYMYKEUh)Aw(; zGv`i#5|9OcX=i`wE3Fx%?i!@3;kOO)>;LGXd==C@_2;vu{TFbMiEFS~0<)}uz&BKo zHBPXdCPn9^^Eidg_k`el*94F|>W4PU8*@2U}DGTgkXk zQ3eED1@80&sDZ)Igpj2Akb{VTzuSMd)Irgdpye}ACZ$aRDx|PJw78xy>{_q`>8pnhT!m8Xc=*ovc4`i83ULQembZ7P*rU=Dq{fjrXT!2|ICxIBU2*sYowe!$W3Z zV`@Gv6yf8*(A6EQZE!fIRs<6&tee>zupQnB_P*Tq#W0T`6O07phlc~h&QTFHnxSkU z=sTH62CAqJ;D|$%^Ygq&!fRNtM)<4QDDGXa7Y*TX3NUL!WD!M#DtYAFhG>yq*Efj~ z%-6vm1VhkVqI7zqL_-o|H0$}&Wm2} zBcHstxDqwJiN0pw?pc2qEvLTOlqc=&;Ce#sZ0xvTm{V}FX94_4dyDv+qI+pKb`8hX zgloINx1OxGi9zjdF&jOp(NvyBf@y!Aw6`)dAeWxkks(j(E!W9z+LQIRyB+>VFdgeA zz1cY(@5y>A$Gl^n+yTs>?t4k5{U7P=zqWLp0zaCk0e3uu;Xl$_dWxioRP#(8ZT{wl z%=g+^rl2HAenZ~ItREF=AD^7Jd-Pm;nHIfSQrdR!sIyI;q_@Y6nXmU!r7beW`m(=k z=NR#3gEX?gxo5LltH|QznkD5P`sQ@krv7fs&@0S!5X$Rh$mPDt z)e*{ZYRvP*BecSEL}yQR-HWxK%L`7*b?!?H=*xPN->Qb)hi2zd7uC<@6=@gM z%;h)S6p0HJA>@k6UDaq(VJ&s_0K_ToP6!mhpI3HIb(>JoV0e)8T zg7LnRwZ??$n-WN2$$C;LioIZKu2h4xbdS2sSgZ6%J9kHsz3j@OHKQdRI=9=p0n=JfnpkuE`e2U;YvFp= zy*gRYe0>f}eJCU6n_|Ntu$wXR0 zZjlEC$Y&{S)oiUV?$Sy3irW~hjBys)l)-I`p6zuT$TthA%+Pl36e$NrG${tm&m$Oo 
zPt$G0|5TR#C$jXv8%uYcQdS*O3}~c09SsanRm7n^i$HPmpg{)EpUO&M1E{E{7@cl!jAw6LN^b(9H))|a z<*qlCwl7_zFY^h#_3X<{>B~p-6)yA@-}PlG^oFFMf_3{(uzMp?K%9(*{fJ__`h|WZ z?LfQ8K&S3Nx90#VWuWg#?YJ=TgyD|R(+-Y_3{L0{PI(T_qzuj>1{W3vm+l5vXouEB zhBoMQhqgS2c2b7+5JLwGLq~T*C$z(7BEx@mhc7&bX|Ga-ZxF+G3&Rh0!%y5Ty66a| z-UznW2oOGkN7Fn)usA|=KSDw`N+vq`TyOM+*C-`?l)8D8mU?lN{(h8!ZtRul*c-hu zX0I_;_!xWh80X>`nEQT=hi;rtbo{;E_(!jC0r>bK1KQ^Ze6@jbF*v@cSBJRb#FzVV zX|9RSdJ{6u6S9L7a@G?{qLZq6lZw`p-@GQZ;ghK0I zbjs3t>NkAK>VC=^jxPio_4k~1g-^RTPkSy-+vrYH98UX-&IIbs1bfZ=93q%`;nf8| zq9P7v!U<+0q5V;dsBp#E1h3g7y4l3R*<{h#H0#+^(YbWRxlFyee6P7ez1iZ$*@DHn z48r+buKBXVxgzWNYOnbU*nIuqOsL|h^Es-SVF4k&V7`UwfGu>DEc7(@_a80{!WTxY z7bXT5CPf#g;EOZOi(`w6!~b6ZUl^e0?cLG+-QpeI`=gr>f?cVFn z-R~{m@$KC${M*L70`TCAzK!1z4G)B@iTCZ_!z>s4ZK(uq&;Wkm0iNImZQ#8);6M)E z%L^Xi2tMHg4dJ|4;TXQq9KMSZ-r&&Wa&^;&n{ot(f90KF~0(ED0b3 z$Rz+HAOa*XXe<97<1g&u%VH{b@&U~u4FvE4FHjCgp5%Em<39e+PF^e-KpY1V03^@> zJm6N4KmgJJN#)?>dDG)he!@VGERYlxz)(#BU;#z`3}8MEZ+AHn6$V(5f!&#Z1Hrz8r-peiE3=;5^In*sneFzLOJKrz5N zhcE&y;0wDh08IW1eZ>pKJ`$9kT^>>E%`)h%F2QA9EL(E{#y~Y|ZtF?|4F7@bxeo3I zfB_b8NEsVI0?5(rq8(=Qrc+ z;m!;2-s^|J=HM;>i_Qz_)$UF6?gh^aP~sR8KP${4?fPE8{5~c;(dNX!6AjNo>h25Q zF7bQ<3_LOKz7Xq!#0z@?@g?6b1YcUeaPrI2@f+XHIv*wnKMW4i3j=@i6aVux9}Jl= z0=@9?Dj)F(pYpx{3kCoKC8O~>-@nz)CLezcD=+j;Z}LSy@dhvUV-NK)-}EzY^vK0O zO!O>0fA#77_Foe80DlMq|Mk6q^JNe5P`~g?|LNeK_GE7l;}r=kzbsY{_wXzC`SBn3 zP7K&$QBx4i71I;7C_=-=?q<K`|Pay_OVQzPV^c9?h6nM$Wi=H5B(0W6t@2IzOY@)4-A>$0KmW<$DS;SZ~ftm z{q~_zhu;fw;$FT09e~gQFbL4XA%I9Y$bgukI5>l7;7FLUfG9|KX+b&IF@iYr$SBY` z=+OAcg9!8J`U)E>J4;(@dke54v4Y#{`wJW_JWO0{e2ko|yv*F}{0to}JxyJ0eT|*1 zz0KY2{S6+Tt;;KJUTJ_R;1B=d*s5e>OX5 zHm%yVY}>kh3pcLZxpeE=eTnuiumlihjO52rR)BsBbQopoDua%ZUb_NFus9UR#5ykV z3*gA{BgKprC1CJ$24=m!BATfdGSuj%Z`m~%h&{JZ$^M)_KbR9=Z?mRfF^)|6b5=H-`Sj!9;jX8IE5nQ*14W}9xl z31^Z>#>pC+bXMMpXP$auXy=|bS?On>f(}Zkf_)aM-F%2H%4nmGVt44HC`Br1rIucb z=}&WJirkW%ehO-+qT)p9sMr8%YO1QP%IY%e2&9i*2?RTB~h<&2|fJxZ<7|?6}NuYi_#g zt~(vN;iAfJyzZ%oKv|3+XC>x3NOs?Zr?T>3%n3d zOmW3Gy(&Qv67;Zf#~gp`amXN#JWIwRpFA?k9+#|f$}X$yvMeUQOf$?I(=xNoHOCC* z!xsMxbkHXGJd9~ZAB}X25(kef#b|YcMfAeSf?2ugp(>?e@EZfBvxVpCQwr4}AdimIMZvKHojB zZtrWL0|!<;|NW1F+LK`K6!nSx{!W63ORsGu__88d0 z4T>&?7TjDB?0C;EIa0A}Dz%%1+7hl)UVvq{1i+P38}cmn7yB zIax?UelnSH%w@sq=*wtI(@&BtCLpc3OJULvnT9;%F2}~qT7J_l&YY$>&xuNC_ClNA zGo~st*-A35@s>DTW;f-i&L8Fzo&4Mfp|yyif&Nzq}>GohhGBu8<%P_vK{q9iRTf~t8@c*0Vi7+q;aWeG=dR*Rk^Eay#Z zYSNtQbVMk9r!3gHJ$D8)7BN-mH;dZQhO%X(J8h~{(F0IQ3e=|&%jh=iMoXjOjHVyN zX-$~=)Ub*bQVIl*RoiJ%s=@-O*W>0vJ$g^CE-0j9t*c!L)YG97^{hcnt6uNNRRr4e ztx8>sTr@UhnG;799e z%hT3Wv(GGSXI~53>_`+CY%OhMLF-tTY80lFB_Lc0OE=fXcDTfC4r6h(kbsUY0uDl1P&%e~vOx4Z70Y<(*$UYYhxy8LahgWrjN%ikIK?bpaRE~&pBBq_#WY6oiwDt@64!XgGq$mR z2`K>~2qDNr7BZ2KY~&;td6`HqvXh$(=tWC<(uQvIqh%3kN^2U?mbS_xCIISCi+a?gF14s>Vd_+?deo@-aG+le z>ts~9ShJ3`tpi=_U*-DNyvFiuDYjbcUJDz3Z>P3< z&hP%R+uXgPd!Off;3mcPI{Egvgx?sPpli71?ri6UZ%E(-ued|kTT!1Rj{?Etcm+HT z@{WuA;~l?o=N#_vblXSe7H>ImBdk%4JNKPV$lGeCIgFXU$1OahD4{ni!9{ z)SfF#`GS<$2MYSopKeNn|3i7^0`_Mk<+0Lc03k+xhh4vL8HqemWO@q zPlfqayIyCr&m8Rm*L*VD?=DuiqqU`SH*nv{Wp}&_ey38u1=YcKXK->ow|@tE;0F(R zlIp#$ddgo7{x6j))^yB3%<~h81{=^>jwC82v z5e8ty!yU&P2H~BNx^#QbzV^g_r0N6AdIkTk_cj%ij8&LcrGJW^Yf3Vxv;_%?#z2a%kc+vxY`{=L#^zF`m+o6&CtbD!i=}#cg zXMY12J>RE);U|Ar$A2JVegx=%>|%iWgMgy9+4ETdV$b<^RffJ@iAjo_q^m|j-Lra(*DAgy;BAqym<2Y@sh%MDWNRH_UC+KJ??5K|KC@qOs zG42SD^_Wh(I4|O8kNcP@^caig$d3W}O8`kN{wRYy^pr4JjxF*(mzg=#Udh zO$-Sw5J{04IV0XSZykAVv5+zmAORvtk|Sx7C5e(I`H`?dZ{Q}9E9nX?S#B`d4KfLC zG}#RpsgXM=eC}sN5$AqD8I(ZzZ~PW-9Z7Hf#$`i;k38v=gCdD%rj&9sm2Xp&a8{LU zla*;>4ifp4UpXd{lyC}{cs((ek7I}kcb15AV`jOQY?*jMsZuD1auVk%773Pli6vwA 
diff --git a/docs/examples/te_gemma/media/transformer_cuda_graphed.png b/docs/examples/te_gemma/media/transformer_cuda_graphed.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf22822baff5b8c19a377c5d4e0d23d3225b0c8b
GIT binary patch
literal 369694
zcmeFZXH-*L*9NL58bO1iD7}daBE1OGQL!M^3W9{*#DMe?0TGlADhSdM6jVy2M!IyQ
z#Uam0Jmo4MEnWps^*t$*X6BmX?i5qgu-*|-L|s!#6+OPRE?=QxEG5p!aOw?s0LjW# zi0*m!{Lr7tmT-J%*K~Slm!N54Ln>8ESPwv$g$o1~00 zcD`CF6O#(Fa0uLJQ$=P{CU~&%9-IwX(%BY4@nE`yCIeL{M_g7a;lup)>7D$LLvX>Q zEw}p*hriFHoje>4ux*EZ5VM~YmJ_aU9XD|VopVUZ)m+zQMmT!jv7WYuG^AX8Jbouy zM^)TY7g!(jCi^;l6j#%ksQ*E!a-bf9s!>Llf2%ZgI_B`zcf^8*@--argMF0w*TNSe z_2U~+Ny5gE+?WmHRKP_nYQBhdGuT*zB+$sSHx{dtm|a{_sirsjHR199>*arEs9LoA zvY-iulk%a##!9L2(tJtY=4GxsPbE}ZgFMbW0|^%ol7$iM5Bdee5^@BF-+JvfjOZ%Q zF2v$kt{^+@*U+$J?#()U3qO4?GRzlmfwN1(3y z0k0*GuA*uOchF}oD_n0rkdn`ax9o5Bo2cD4SMiK9h&||CanSJySNfc&&+VlTr5W$H z;3-KvT>d4}xZtB(XXEzV1jE@=_P3)c6^`hdEbsGVZY6zCAGmzE${)fw*}-yueLGZ+Oi3Qn(pCM5fTZAm&^o}5hZ)MB-5 zSBq!s9NC6c?6T#yJnR|xd)oq%$Ogm9o}?Lmj(&ENW#%st3uL);;o&Z8(^kkW-4b`|ZEL7`5;mEDT$Xf4^!4~5EHroV27 zTg_8QB2^q`aEsOR9dI+2qb`cEe-73Z6UtSYnTM6$f+XMe9-meagYrT%VrNsmH1^Rr zIn>?_@Y$xUKlAoTAUim2eu<4*EhekqeI=i>QgZkF#GHE7>>Z;Pr``)T8&rUMERi`R zk^K$N@G@w^aameY^35c^3A09(mz zldA>n?Y`I_xF-!Y%P(|)8O0Y%T>IgW7MhQ*v}qq$x~Ir4;)>0DwsD3#fUN2D@LpyQUXzz;Z-ass5fWpy))gJ0x-5ENY(Rb!zx+YZ7!cboKSLF{EKWRgi4L-%jpV1>+_b zH_qbO(ggW2CTqXNipPFrd!SX`4e_ma=GXe;#l6fZBiNhQNB2VMzj{t_fNl^B=>ngR zkgU=lE$x-8_s>UQUvyg#0;fnp=uVtndwI?lC zq@8bwY?DN#lawOtHjOt=iV@n2%?T{+(YX^_mG&FNq~zm=?olSF^4gu%*Z>C(FGJ4XkYN`NjcNl0?#0j(~uCG@4E zt(kJH_#1sE>gRYHM<;j3jk5OHZ}Ai>k#hDwb`MRf28h zA)D8O;&Shn*=Fl(shCKGq((nrr)h6=3*1rRlWmi|NxS63-}8~M9p~F$*uwGLvxf3~ z4$K(FwpG&J{rsz4LV3qiaQ^_SdZ2C$mI1OUm#|f8;v6!sN;-OuZ__6C=_ceD=T&5K znMN6h=2qoST?%e19*pjnEuS0;mEnnh+jy6 z@ZbG5Ra+$YHhE^ZV+vAq(>1qK3^OTB8Ir&w{Jc*}o> zeg2=bqTTNrKh{@E#?oiY-a z`=-x5(_?o!BppK&=(2<{@7+3_4N;_WcYcYt}aajPLzgJt7B^bvm z5Uv=q5d&Y?K?lZ&Tr=*Vz7LQR(Ye6%=Rva{Uy_Au9fm~miL$SvNv9&|kX&J#c=QPa zfBmh-H7Y~*(g;;NAjq(6L>|F>QSs3TdzzB9<2ImDr|O0sC!COh0*n)U`G?}U4HUirLi>sG-X!HX{DP6p~a{X*c6FPna9O1Hkm#m zUBKA-XJkf;2`e=nPZ>qEWt?9>aXh6OTy8?xq*7vjMef_uYqz@nDMOEZY^dW>7!h0` z;q(>*k2@$y$U9|7hrv+jG6Sf_a|VIwV$r>2u2msQi=RzhdP;rz;saj0mPvmPtL!jm z1xEnS;_r)v9Z#r(H}dr1XkV~2icS?Hu@f;zK&aoEzHoPkE3N*Ocz92dr|3=ZJ0Cv; zY)OaP@_EF03HRI>3Nh>wet7$_WXUKzz2JRC_~`;6dihPbv58d>IBGjmBw-yS`=cDQ=U1M#utta%_<>nevC}X zg56{ORLJkc^0L3-8Y+UamRn2)4`~?k$Y3RQy*60N`y?IS2CuXzN4ihDIRs<83`p5t z+SyKRp(*?f*&-B}i_(zQfWw{yi5v{GDL9*Qt!B?q?dDaDHTX-_RU%th&ogg9k~)Z)Yfe zmh7o=pO%4d%lyq(Z@)nQtxETGY@PRw9Ulj!hQw}dkV(c(QupCs2nwR$VF%7hvqN7l zUij{hX3dZ+O6#929A9Gh5YFFp&SToAhW&}36kf?rT7P2$!oM=teRS-d`8dQ(>PnZX z`*A5L+Vs?G`y#57dnbv13X>_33Yt+znQF2^g2~C9SC2>R4YH3qH2G$^aPxdN`4)XG z!6(zuh&N+cqH<0DCxZ0>jBZ|kypQSdZ@R_z7?$`}lcat7by(sP@bOHqSJ-Z&I%EF5 zl8pHy_Fm#hC&T>ksH<>dx(D7{O|zkxeCG8&M%yG6;o*jkSqpxP+qTea443u`mbDM! 
zL06LEg1QSG9cd;wp32UVz2p0O2rZZ|rAqlWt2c>QKN8&BDqtWu{7+*oieF7Yw)(4G zwt)3d(4{_<+Zy`Nao$-@!BECj(~8xKlJdnmE^^G78~84-8tHcO+P8cdx1Bbde>jmh zqa%)^ITw6t9|f8>6QwKIWiKg6Og+>;AJD5qblx=Ye%X1(V)9E|Rzmr?2zyDk!zMuK zCrYY2;hm-!?46)bKOSjdxht(21rqFDX>ncp`y~fJcT*Y<`xl4DXFoQ`Q!GSz&*`*s zKBJoa2C$Qe<1a*vo7pktK6&1;K(27LZ$%PEkNwemVZ^ExMub6mYuAbuOn=qY?oTFf z2rKi>i+s>OKj#qhCR5=s$^9~pbZDx33 zuZqq{$tw32AVeM*rbkcx{n|K~l07^6oTxWTM|+CsYiIFhqlTAvbm~q>MXa>n0@8c( zxsFe7d7>;>dXwouSxtl!1V|0GKZ8Ekse^G-Lm`uYxjTNB&JsFFQz0!%>L~LsQi{re zjbWoW!s_A>XM;D~NY)MC2BwR+28Jc5FZy7EF_JH@@LkLF*cZH~JoqfAud(QCwriUw zLjvuI6{D@x*HkQcVDrI35b!pN>7@= z!6PNc(PIriv29@I@S{G!$otA!w~HB=r)65U=S+&AY+r^VrO3uJyqBZjJVDs8!yhxFmw4d8Gu-;kI!u0RMgXEo%^ zVk97Wm^Z1(uB+(h=R2Vl9$@-5Z>1MNTH3y9qtrYNg(9I#M~hsBvEP6TY=AXTJ>PWs za?2qaET)Q@uIew9`e6AgPpV(Arurxj@J`~lAW+tuphz3b+t%;A#kd;36Ijr&uKc4V zG=DzA(1q$_Y z82>w)Y-D;+mFK3s@;o@$A_>Yz;fcR|gszNZmx8K+iSaQFZfqMq_gG~^u5pYVbubrZ zG;tMC$~|BYBfc?^G+`7S`0*a#AsSl1Vz6_d%z-36C7%pXk*55B9zL#4SY^zy3SxDi z1x1Rzk$wMf+6L%~|FtwLSIhnHof)MWsL5jw6?t=6*6pO=XGx0)XjtN@i@MA!Bx`Z` zF+8fa;xI=w)z4#{Kc%KM^3T^Z9X?-l<%l8fSTH}x9_pBcV^aWu)>aXNGD2@Jq0X2c zbrsrpZ&k`Y)pub^;&RHX!n5^z$R+Ii+;eVWNhO)KtJnm^JJ8yiPBU*O<+1wUiwQ7$ z#s(NkgD0S*E9cAAZB(t|5pCP`E7i1!&Uc<9+mQKCN=qhSCD8MUOu2NCTVWKerd;QL zQ1Y?leG~32PzfF9zvOk!c%J4k=hB<}g7!P5>|bbN`^?5(^XCjs^_1jg(y)gc9&FGU z9+lvG=hZP_)g^5auxUxa@+pdGm&h-Ww6fduOso+H=7z_!F)E-Obt zC;0;6O^;@VLeW#jb=^H(*aR1)P=ytYK4L(m-hewjJ=y!D#$XyXK&Pyy9Ti|*k88x4 zs{MWG8C|?yg;^Oh6ZyQ!#MkawTu1b`L|5;A8J?okVLX0GnAA6$S@A(3zHVm|5;%R8 z?1f#dZ66To?vSkfLs=CQZkP9tX|gw|t?fm~g!ncKL1ZZsJ;v&WJ&LHbp7MD_$8yh} z$sz|wv|#W={YT?Fu_V>HE;Cb9(y^S}P$V$m=wbFD6g^(~SZO`^Scf`dE~6N(s`iq+ z(d0062KT*}+63_-9qSX;zCk^5^xi_~`fmXjTu@FZXGkUMc94&8sxy zm9JUOcYf&ub24Ri18zrrj(PVE7aug@;nc{1V1DR@NGOLRb!e~5ju3>mMq1}R!RAOt zI5U_MZen@DNXJbqH=&{bMg9KjKhDk&3`F%`eJk4i*&EM9KD+0;LyePh<{c`PkQAy z%I$4_(%}O(U6}Z5=aUe$Q!zw;;o|vw_u|({KHD1^|21U)Wm}&CpGQAQD(fEAMzObr zUr$_}?1%qjD1d?>KZSSVw9XW5L2=no0JaWms^xTvafGq1*Xz`qVKcoIhgVbnx;K6=L}3sA9-qCoNclb_nD4slPUWpUr5ag zj)Vj?>Zuu7^O}=y9*Ku&G}dblk{O(^_7GwZ1YRYsDq^JnbseD>id zq_%i^uJ>)U*OZPJaL7rf2Xmwy2i~#RhI$lS)ei!6np#ltCP!Ry5!TLZ30$@hx&Nl3 z>SI7G&<~R7tf3QV&kjX)w*`_Tiqs>kbAE|S{u%FHOVAHnpR7Lp5_=#rYk!x(0J;zT z>J$894ac!doZw7kv+HA3iEX?x1^k0J2|uSHRG>Z}qm7a2ZL(dW$#T>g8Xe4^w>(Y_ zPvexE4xhZzbps>LMcFYwR^Q950sd3%z-((*@|XJycRpn+YoGIYFE#D;EUgH9n<4u7 zCVaOu{KJesNa_P0vtMWgRE%9##e_x>nz= z;&5?M)j*;K++^c<0Flc3l`By0wQ1lr zMJItQVdlK_=qGFu>=uS%0U{hc+?0@HRat<(^P+VCRAFSqOoc#X&r}8VcMZ)`-sZpi zYzuF?-2Y;pV?*Vk?DT(^7?Z9s1L2!u+^o=r9j!SJy8w97FE>5Yhuz`5Iw%s$ z6LD`F%LXt}S|$C5fuie7K98$;AKK6v{MYcv`rV)m(aJPB47ZhOU(NTH0LZ!4J6V|+ zO|C@x$3tlVcqr77zLCQc0CE*3NbNnMHs>7}l-Id5jvcKA?K7?OnN9zM73<6ulkp2b z1?{U92fU;+$I->-l*K@GUNcJQu`|oKCWP(n???kn3ut*eBXleOU>gk4He-cLkgB7>6#pV4m{Y!RT^b%v#5v2U9!$K=W2@_o7}1#TSksn z;|*B;I;BJ5+^S{>$J0&PG>PtKCSBzwr-rldI*35xmP6e&B2KPlG`8m)da7c`5#O14 zf3->IUuYxBJp`?K_VU8S5+77g3J@F~IDY7qvo#NdX`+i+%;Mq%OrBjP!8nOjtbMW9 z^%ZsrY+&W|Y=`P+oSr_+O!Ie%Hv~OjeZ$7~1wQ_6Csq1Z5<;2ofHW9RJk1W@fKpSj zX;cnx>81@f1vFGGi{vP_|1*9dc6+s2$sBR4bY`GKeoIDYO{@IL9&y~VVwxPD_tUd# zLw`u@rf&Cw#JnkdtYpUV{2g8*r*1gQhz3;QeZHpkhhKySm(QNl&LtgYMfYa{S7z}= zCEKCF32$?aNZ3mFpf47xss;S9>jGMt;TKxdz&A$$>YH5U5eM(7iG+UX5~%d-#@5+R zHSUs;-sUR&;(IBV^LW1LB1iGrYF9(2_5-Rd$1v3IG{rbT!9IJI)WzUa%emP-Gn%qD z6>3zT>0w$>#ixoyBPC2_o(jLn1+t;}HwaO`a$1I`vfFh4-)ZnyIu_Lp*E~z??i-i2 zSevZwV(3Px&?Aj<@frPNw(n43e&c+<_6@Y=!a)9AJZ5V6luL6rt*$3wMf%-|K@YJ} zUzCclalXTse&sFG%@a&C_FN&XqUuZ<=u6URljk$om*L+(0xU-v5?}wMHF7~$I?$!; z?R8$~JAf!2_<8vt;|6=qYv$L7VF~9pq8Vt@HOT9J?T@kYja!G0hFH|*;Atppwns{| zHMV^c?>N|E65H5v@D71Hvu7H}G-cdW3&|L(57}9y*%y3xZldJj3yjp~_ip@136~?8 
za)jRg*AwLb%5nr1?9Cea`(d{6E+(MS9hUG(J_>Wk+;LZzdk2syb47zR&z6prcB@rp z;FU*C2pZEcMUM|>JajLs>VQ_Ch8x#d3X1g2rh4I8{!iGPJ&>!AumjR znf9P#6YIt7KLkC5tyo8<&npNmvR)@u=Al2H7f~kE{MKHUG#;#;JxUJ~s^01bXze{o z1Wc)z0mY?vuiWg|`FAGW>%C<<aPdX^a*p#OmT zB7wQ-C#op4I??DXZ2Mhp)~MqPOsen3HQ(qO%R{;(+cZuK+vt}}$G0ejT4yHT*}B}H zA+}|enb&#kNma%Q+bauUzTR-Kwh)~#p4bwpWHpV(25Z%GjGr$z@ptp5~c;6XMwLzG21-b??w}X z&Z${kbcMoKRyw>$EA}Rbm)b03j$j%y(eb5-Jny|M^46c~FXkIDyP`LiajEF5e~vx8 zzujP`-0Xx2|J-6uwg!drh*TM?%E+1Ye!M^6`XD&U!sPR%897r%F>|=>*B|!VH4yd* zbED*W3E)|&v*VEMkk{|o{ybEGLR>W0M-k{b>cXA(h4=SqsX%>Y=Z`B?xk8UG z5%#FM)#N*}LTd_$`iGxJH5cRe_(fmMZ8b!JQecKYVdJ;7(Jk%af+FVq&OhS7YOhia zuL6duXw$A-LAP0nSKq$BKqN4i5fB0&ca)dGOcA4D6fT|%{oyYqk7+QqPjeJfeDBjy z>>_g1q2&xd$M70|4x9nXBJ^h$sncOqahq3DkTPCJS{7`$eY5cDBd2c4a)D0jQ_@#{vLey+y;nFm!c(-+xaujkk# z621IFX3BVCkNAk9+-1GJz0r1m^xkJWy{{}mYP2Rh9oUT_sfuHgG)M@t#X1IlmHN+s z-;=X7zvg+3+`W|LKfgg3!jMm0Ot}_4AiwjS2mhMz?1S5FxHk8Nz#zuG$wgebGwVip z?z5~gwtE(O0nxIlnaPPI(q5_UHB8H{efP9ejSWSfqh2WpnURNWSbfvI|Ew(}uX^xH z5AGROloI}EWoW3QdxqM}Y0?E^*b`JK{y>C%doqr!orfcrcHzrtd{=ayB_v9z6yE+6 z;2+46Blb%-pqR4~yAM@|s!!*Z`H5gQr%zg+?1(5!n7)hiZrEiV@_S}E{cvw1zWZ=b z-UG>~He7k}Vo&j$5j9XKl){1eG1(fIG2L8eYfJ3j(ac@$#B*-+?+cqB9v;3P`@|jn zI`I_maVo9@!{!}S$?My<&&DL(2i;8Jx;#`uRP(iE#@^zdk`Gy*+Aq0s^v5e@xS}mm zw!qZ5<_qQ^PX_9iPHD4d8eZm7XSpOEe26S=cX1%!u_B6cRrx9WNF(2K+ipRPLBDK% zL2EPpKK*uvea<}Tu6ed4D>vm9LdQdp0pIty(OZx}{B_HALhEHr+S+$t&B-X)mAdMD zFB$}AB*o`G@+6+>YG30~Zp~k8zkfL2T>@Ii@L;;9=5uPJ!LU_gnJdeKLmmR}+ejt` z5-k1KN9l0X+=o?5#SKCXDCMC1lbwsZJu3p0&dTg8ytO7~sc-H4C{aq>&<*4EH&XIa zEmcD|%}z$|2eJ; z2952czuyzzb;oIEeM(SNdbNxP=lX^HH*%lIj$!5{@l{Uogc%IJmSUp#o_&3&4Ahxo z{o>msMvPKqpnsH@*)x6gP=Qoa;dMs#ZM{d@rz8pm%^?)SZCjxkfG`+WLlF+`LWyP0 ztJuAwg__pI?iRE_fcRJn`67qI?B3lPqPj7OuCQum)mJ1;fbg~uPuQU{pY;==m&u!4 zLLTgsgpk|Ks+C8fw`(PUS2spwW;1U8iF@&0z{>0<;o#9rLX5Lr3YIgMgX41#B1NBu z>th|Khqpu9IxxVlB{N0fmt*^6C)drd+w43 zFAekpFcY&*Mh(nNdQ)=K@5gl&sMQDHzTAXd%@rBDkZibkWLKNPKkg~!E~|`%t}Vp7 zA>@3U)e^MnB{R`9CY&+P=$z<0!C*2zK3S|fK9A42qPnAM_-S$PAs@C5;_|HlW8!uAh;H5KGsc`~ihE;x&6t1&8>dIe`X&m5)Bx zf`XTOVqAA49sVko{RTM_i0V|ekg?T!mHFj}dS;;}<5O{eS^mZ%h0(}{j4sWUUXN5r zel59x{7YFOP0_nzHMORqKQFZH1#{z*ys4Nmy$a*1-iCnAAPzV=Ne#aAno=#a(CQR(* zhRHh#TrGwqgo7r5WHR<4-CZAMssqsjK}pH8Et65O3Ek#K(j)L{%yJ>LjM&deUXjSZ zJdUuvdt~}po=moj=?eYv{JnQ47d5Pj3}|iyY&BmFNm%AZQx^*@h$f5p!f(o#VysH$k*;+sMQf3h-TDa< zKljKXztZ$$x_B*9rFhAEPel2w^n21gWp2RG-)F0@oIyU6S@kbZuch}#_=#JvgpNN><6~VF_DZg*;y^E?=`1R9 z7uF*5b(Lo<)t9OCDI9(nJ5X|z(V$$K95=Y`EyI@-VZn}${}9{jxfWas_+81MVJ>Fa z2GPA`{!}~_^DFK*$P;d`^UVesIctn33ywklaISOM z1pbbkJ!bJv-;^F&8!}sx+$bIEU9WL)ZV;5q=kjQsHH>wR zLlK}%o6G3hLV}-8I7sG+#S;y>L-(+F(0Ebnu&eYqYi9ybN=PYGbP$_JHhbueA@VSZ zFweCeLAS4!ws(o!WZfTlO{lAx+>+2Ic?PN!z50c5IZ60tw=PG39kK-Wdbi{pxG|3z zcy#m9byLdltsj?iM2)yI)3~B}Kdw#@H78pJ9Zheem$9?p5yz$a8HJ9NwpFoEQnFJg z%EHkdgOT%6jBhp%lBHVwSpvDU=eL+!ifFSa3U;mZY!Zm2K5cLVI}PcsnLe=70#@c9 zA(Avn(HS$lpJsPM5{?UhnP};+OqA!F<77eof*?p+DE}*aVK_~Azb99c$p$1a2Ei0k zbFw9Sp9@0x0j|!E@r+O@^Su~+X&uE2|IL%wM$IXRxOq0uCWDoqD2zW6sW=ZbMA(6( z!QHjys^3Y2ST2NEe1~`2Ogw9dz;wcBha_C3vD_4rzoCMSR7N%8ORL_Lk*B``y^)|T z(mE#9w-uNGCm3%p9*^`I62_ZnuQ7R1MpaV>eZ4}zSI4|riBGeWDAH})H@G4uK0=?T z_~ZF;vp`DgDZ063zsNJg^+>mLdWRpiKd}Xd@hz&Y9QnD<+d|%rd|!gy@$ymRI`h8# zLXwj5wBys*z14w>YRU|%p9NN_`zYi6Wm8@(_y&h3s7-bXx(oh+A9WVuoslp))EcmU z)k%%icKG5MMmw^jGs5WxkhaSSR9!*H=og9RPKMb;DS}!ju@v6{KsQHa){{+=0;Y&P z-Y6F?+g(pUu$)Wl%|`m>%bH6}ap}>hTZ^@h2CtBC1_)mm(sVdlJ$g;$l}$xOh@zzD zia@=bUj2N?p*SM7ck~@!VBg)5X8hb}d>R9Djif?l$$AY&c7(tm2AXG{NBg(QU(J8s z&zl=!FPseM_Zw~JKTVqVda|APjOnv>Th~Y+lDDXrV8}*zZZ+qsKtSmaK2N2L`(M;R 
zQ_4nL)Q?!-VB?v&L9qitVfU5qUn<;DsYfa;$2A}2xb1?r%U7g8dMD|DqJ0**%`@ZcgSGDQrT@h9I-$;G=Yv1n_9iu0)htY*+d=Z~qQ4 z-ye&b=Ax;`nHZvmWVygPVxIlEp6{OqgY1R3-kUCPc_Kv*o8NMF+$QTs?{AwL0eTzI zjswEyr=(;pFCgi@Z1zWe#6#7a_4t;a(a9vl+~{$RD83eRa4XZ6;FxAKjMNLH7xPVM ztCFhE7izH`kLyA;J-^^s*!56FJv8_{*KuKFw2dsrPQ9+{ZroG&djF`x`UBo7qThQT z_gP3!VH0h?<9bZLO9?lE%qS|7aTMhuY~L0!_aeL*4btBdn9}a@tI`d42K9%swbJL@ zTnHeC%Kegjn~^n(TtG1oH%$GDr{lxVw2@j*yHVpKezJTs#x~#O3toteMlU?bz#u;n zxLdNZR=Z5#=Q#f3Y1n=Iy z+mc)DO?w~1?JM_Nr5^g~lz!n3#y@SPuN15iQ0%To48#^v%kCEhWqRbH8GV?%Y1do1 zIOt_|I9!##yqKTn)R5%<$5q?ya9{>)dnvw_qU%9DPYIZZs=;5ml~XY9yq{D5qfsyA zD6A2AiB5X5zZ3uCImKu(D5;W)9@s6oZYF|uyL;wnGjqiT7f%d9V-<)?5(MC{cQx>& z!kPSP=B3=fl>uFP)Ft+|&-JE(gyy0U74(i`FL!Lihj+IqJ>Mrl{&|EBVTPnX##Rr8xqqxLYvptmRDLPbA~vh}w|>E`0neVJ^^4UxZC zO5SuQJ&;HSb>aWloz#_s!m%O3tu;LK$~3ZcaP8@@ax;7(kjqYMDo4G^Mys$PeB~S1 zzfXKI+z9?}zW-60Z)J(zC3z;mg>P)XKkVA!sK`1Hr^1m$y%w z06yA@_6P9+}c_k+;&zP;yf8_9J4{(Y=~;6KYzN?He7jZwPH`+u<7Jb106&@!YGDT6gfO7hqLijIz!txE^dxU@M5MbLb8&gb^$r+-pccx z5wCZz^=|djzBhHX_5D80K1HlZ^`dnm!cNTjRC-1Z#8q!pS3;*V#n&T4pN=Z^cD|wIT&E4BQdl;r&Kso+M~^|o9YaSrdWKqdOgGE+l54+(6XRy~=Hi$yH%s-HX=p1~wYL&91MDYo*KX2--8o+mHY1m5 zs!rSg)zXJtfueS{fhf1T@~|h4Ge9uQqNUaLFjt-}DAV(b&H3Dp{}0m3NNmo(%B((s&nPsZ&byr}AN-H#>EPj{8o2lbC zL^>@4ZC2JQweme6MdTm|Jjk#oGh`nM2?vOlqF=3r~v_x1_nCqq94f_9Z3z$jx! z7cn!X7@7+16Q%=+xth7IoGBoYfZ;!Uru0;5Ik3C|Dvl4#3ADhcpM?(P$G3QC=13ga zhjOOE{q0=-&u@bl&Hv7V#UJr*4zvi77ZPrN|Nb5HJAymh*qYfq3q7zZ2Zl5Gw>_Q~ z5ox|~+%}(Wxzl_lpU3HZ*(Uu}XZ{0ezLqmQQ;m^Ec&o7@jzra^=43iEN-}&zrY(%m zke9zfutYz)eYp^JZ=@hUdGts8!Y=<46my>JCMfEg{P7@HKsweZcQ1LER^OeYqyIzb15z?{m3K3iXi_ zvgktfF2NliTZT7=iLt&wd}N-Tf(t!U5&+OF<9T#K4xou1_;gY29Gi2}nMIG-ja{iv z?iVnJ-RwbH^(`p)$>?!a1NNvG2{blT8R#zQdpOQ1LdfnxZXZ{yJ=+nlKEzFZ@0 zSc?O}%h%*<^6xUid45cQIwbRI3$51<>~t@ge&+d7lkAh}@@T2=OzBFIQGNf3A(wiAV>H}=uN;1CD%*9er# zc4l>R#N5e16g8z_)+Qb&ajp5hsf87zF0;^g*JaN&-R!P zKNMaQg8_kDxggB|K!*Vck$F2VZHt!&ShcUffCr(H+2LP>QbV9jez(orgSwdwPniBOIUi|ye+iNnZ)t!l%njOQaBBfa;uk`TdJX#(eJ+AK>E2ZO;P|G)uiHD|R6HHxB|IU)sreKQj27tFvo zlrFR}?m)8^}Z~CVbFrK+d1cUmscZ$6Jhr3m>Oye_&hqQmj~ir%;io&%9rnO|>95l@44sUV>CJw?JCGmOYrG_t#>{@4-SkNJA@Dp<)csZT$vDzgZI|I!5pDry%gMl6*_*KREgiOuL%u7)Etms}O zZ7>}qg;rOI+;W(WVGH}Wm96J(xaCG^eF4guQhn31nkZM5^VdRv^rZaJS7tn$MABC< z+nb+N=08!2AjOhoL?RFdJRiY5z1m`2*}1@YD|@o)b!@DOuaV-`t)rxLZUwErDKmc? z!@RMV$d@Vx#3^GSNwR_6iESXMH9^RedXT1dsl2K-UB+vIc)s&-cMpnd7;#?1CKP>y z?2cA*4CCY2a-MGk8S-V9%N3sY#Lu(`L3VjpG;JvruyC)G2Goiuu>2>I+Ec3 zrQvdXyH=$=|Vqx7>^=b51EvbTBx5OCX!M^OuZK!KP zi7rosF0mc%)%t(CWvfE+y%SeoL~=#fcwpbOCe5jDN`ypFa@9c8d|W-?t(*Rpmu{;; zoC;in6s;rl$7GDs9$Trdp`>_J(0M@uZmh}#ch&OOs>i%*5_;zBL+@n&geB_x4rH1` zmSvci#cSL>F%5Baja&q-SO>|GGBWfx&3ubrq`YZv0D(U=zbn01b4RNeS>SB{4HNJh&|zeko5BR`|L| zhly5AfmiVQNONzXzslJfBl2ux$hq}k4&k(*wS1&%*oQRK3xn3oerc518xKRZjB7h` z^KdZ**P#$pLva;Is96aD##+YL!D>0$LD)Y}Fp;SuD_L;!WC2ySZ60T;L}6~#7CUe1 zSE<;CxOstEM7rmFIyyfDPe*g_ma{hI1e7y;XSQhF+j;Sf?=*ab2nk=!pxG0xd&y7@ zUl&5jL6YrYXZ79L^J+T_DrPz2v85F}z$qYHwvJ%MO3!mhdBdrQP?aNL6Xz2YH_zN@ z9PC5=xd+TkONq>TXq;ye+?I1c;s@KLUTt6#g?czl?f3(?P|CqD7xq!J%A3|Wp$)!G zdb%x_6`~mk!_A_ZI2eG#xhm&eX-Hn2D7ebl>X^UC4xy&!vtO7+Yn3z~*EJ@Y(PZOR zfNhxe^tx^w;O^xN{-{p2AE9smba+HMC%OGWx2E;goAcZ+TWZt2uKOcvUYF>CrZt1! 
zzW?#RveyhQAl3J3zYq3Kk)9zDg=_}GuPrab?+acL5HQVPN%}T-*-T*{^P`Hl4Fv-g zk)5-QmIn$sS1!)CsvJ&s^o@|?yaytP*J*<-49~Dw&+He(g<$FRN+?ec3DPo~D5^I6~N48v)9X@(De znw2A{F3T60oFpGy&m5m%4R5s@V88+p#ctzQkr%;R4QD^adqtaVsEAJ&e8XOk34@MOlw15H5hhOm@&W>U!<3xR=FxgsimSvp@DQ>$&dU zdx0B|5Fgss0SeiR9IGLTC+WLzzX?MxvaVQ;Od=XP6TBcdtS_rP2BKQWTkGDdCTEom zUb_vo3&%&$k0CuC#NF}gYkl6R3Al_Scg!A6zlcI{1tQ-$n=|M3H1hGB-se zS9&rn%?T{FkOZ~FGJg_UIdR=SaQ1}t8!C@GBZw3kVoNGH?=$y1P-7W((c63Z+P7ZI zkqX^i!A?2%X!bq}ns9hK7XNwVpFluq1gXz?#+1TA1tFT6>)NUq(B_d-xsjQA`S*yY zr(xjN>Wv<|pF;KH$CCCu+Gudy?Wj`QtS_E#AX!?? zu;1?{z2;)dQc^*Ze;L2O3$)DM03!?#-m7i@s=d$D^GnP=S4eOe^Y~%=I^x*Q#+>)s zoJ%7-_T0Nd2+R2mm4s9i>d;m2x%Q@g$$VmxJU(wF-$^TO_tWVMJ#$OtvEA=kUco0n zU>BPgXfKzS?f0N!-v;<7w5buT?Q3n&Q{zcIuZl33Z>)q$VFzx37~v{pkHd2@z3V-} z*6ZbH6YIBhI7KAzkltP8YqE6iY?Bt}3%Sdcs*Y>fiN!aSgxly(2$@lc8d&=V7hheu zHD6v|`79Z3SU~bZZ5mqJkz87#S=R?HbzO2z*YJ$hz8&vy7a_d!4_OWNf}Ix)Z5W6> z`>kR$k&y{&gYzHGZW#)mN^UzFrDP%3z2AkO;<8qWZ!3NfXuVj}G?ouu{p!qmOu)3R z7xQ>MJXmO6a)CLh%*JF&l}iL+5QZXv`2MXlMda6`Jfqo+|YKiFQ#O$JJ%#steGh%fX~qX&p?MRF-4pB z@VGv~&FsO>vUvEu(CJCrb;-epJNj;$>93K_Pul7nw^^I|vo{;kcB3_mDK4Ue4!OoI z_m)WSjazk;Ffa=JtiXf~(4t^i4_WUlm?^s|9H|ui&Hn z77E4JnyvOxP=^PwdzlFhXM7*|gIB6!ku- zNHHO1yCdMbcZ<~f^f8!@X(Wje7OJqM;K|?Ww(~TR1E|J7KkP|4)$a~YQyS|nV`X`P zuLXVznQJRQv0|~$#Y7P&EIQXqh5yZz-)tniaDA>X=J%7^E+z2%^RrkX><6B4WR6T_ z{(|v^7B@%y`bnwoY8Vu%Vx!U!4%Z$z`u&X&5xx_6G+ki7TS~QGZ7V&6d~XDUiXx)i zKi-qwd~4)>R|cO8Kny$N(6W)=qh2;9cOm1GaN?gcZcbuP%C4?*;%5l#8W`CH@l^1j z2$wl~TlLWNd?kX76FkYuVYQQjjnr6)&9Bd`{$$kAx8v9|vu!^JUSB%SG&+d=ggm)= z3(RIIf^3>Y*!#iK{J*c?zTL=^SpAN7PPT-BFB4n!<)d!6;F&Y94DtBgJXjK^ zgxbGA`%i{&Os7B@-IM{H|NPS!1^5r zOFW~Ey6zy1`Z%f#2t&x9oXQyL~izv}vC#?|NI`@9-_0QjpEC`vq9f8Y$D-GzW4 zKX*8l^83K&wV?$(QN~?tBj=Tn^aT=uzu}vY?N}~Wg?LjyU^m+oczrd0t$ua`1AjjZ zt2gbRk^e&*6T{~FVGH*dej`-<&#Am4yj2k#jO+)}f3aXY&)CgLi*46eRY!viI5q(= z?N>FjZ*n2wP~&PX2Ap3XCh}=r?Z`C5e|}e1<=yP6OlcQ)Fc(KTQdNyOgl0JVIeDAi z*T%OtjP5Fh6S<&-_#H{he}iJk@cLZS^RS*{eDefz+?zuG-rCcUP)Vh@KBh*aC8{&nqT z83g$$`i_52WPpuOW#-l){d1(PZ#QP1*nh#8e;q`tx#4GnKC;(+#?_Ka2>0lI>X`o@ zC>SQR9z{A9m>4?mbqf53N0K{AZQ-Ztu|Z4_EKx*=WNsvm_}`m?!H|Qlr+YcSpnR7S zMS(}j=D2ws-z>n%L10NNa=VHek^w*6Gx41JiKr(4V*bQ2wTyc_CQh>d4U1u$_SVY-x!)&C z6QsP{8~6`1G3ev|-*tl&MowjrwZw zX|-1QVjLB>(BI72L?w@TfNVJ&7AgD(G6*)=nq;&KDPf!joI|{X=|3cs|A!N=F9)v| zK&wdnFMOqND#f&ayz?K_$ziv8?9JA|LNu z0Up#ilWwTW_5otl0aMJ(7~iIh8vX2*S6Q z{vOCS7L~t|Di!<}9J7(oNS^cW?=6bWd~ihKwx7egX1q}*i;I$4DYDOV-~3w{4u26| z)-{`Z%cuE|Dv6*Q;gKaC*=6Z-6`y8mjEpe&7jzrS3DC~J&%=Z4!Me|L3&4;Oafp;X zdlP|w5$?!6tHX0-lC_%O>}6HuYv7T;&>fCjU}#d^m4QeQzE>71mG!TK>8SuZQ0fqo zmG1vVr`3`lIU;Sb{{ve-D@;R&7f|D`U|ZSZYqs+zFw=J&e`*_~by$(gZo=lLfchhL z+pZ0$Z=`1kbK9j#Gu_W%Qg5XjosxxA$J#^n*=aI9F zPFwzvhvxP9WSi`J*B`m(ird2OkfVm{v#je_DpKxqz_0#omn+OO*KEJE?>Y=l7 znp0xYSxeL|_wBT?jh1w(I;AyX{+HrL7Qj=M@Lo-|7#be-JvexkkzG^s=;_m^(sFWQ z>gtI-Ymkmh{^CgI*RPkBpy9?WiU1u%+3(IrH4VK2=Py^eEp2R6OLSFwdV0XL^MJ#( z!H529H(F8Vy~mSWcSTs%v~O)z4*cMoMHY5^`LevJpZiut!_fG6U|Jf1!&Uxd!f%Z} zj6!2j)OfX3)s__GWulVTf8S|$PJsXHN4%w--D+bR{-a_NbxqCi{pmz{O|ZiRN?_-8 zGV8+IQ2&XgZ)h0nQSKu-nWrD^GI~ye?rfeA5E4S052&*|*{oL+wHf5&>M_ht)!oII zo4AJr{C2QwaX6={lmDV@GArnWM(vcgre<(V4Eo)B_ZFwBO?W%@m%A6&)`|{Xf{eTc zTU%T2Lm=X=t~>)7e8Jxx9Oy(s*%*>UO<5S3m?U4md{-@`7F6-QnIVeW~np&T{zYZ+Dize*N0+Bv$!T>Z;U_d)^b%$jB%c zuF4malo}O&F3*-ff-15s!I7;XEsY+I9d?+zasKD|d6;3#G?fL786@?aJ3C21J|>a~ zwjI|rUG+!P`tB^zT|zQ4GRFL!4xH>*G@BTala@1*+czF@l7fL4wh^!H+@ z$!v`ZE4epiW%0Wm#0n)o`+gLUZ<%e9Jy?95@9osQzLIPh)Yj%RxGy*O;J5vm=Rp&3 z78@IDaxjj3gX>&kZrVx(dHEGU{y4=37^e5#;b&_F>M>9PwqdV`XBM8!CMo+;y)|ZZ z`@h-!Wp;{YA(9#?teGvmKn7u3i#N_BZlB!!Jpz962FbqaGWNf?zo(+YvWA+KRhhw@ 
z4l%6BujVir$a3F$#;Qy{O-dT&AgBtCZ8}Yb)O1`%Mx=np@d5}n+)y(eBO^ASvGP@N zlJaes$e53c5U+vw#|I;@f~_Y4q`$p}h#gVPw}OI#CIb)-f=J;f?zW@ws|zF|dzQYz zz%JXPyGK>tFnTL&)9RpS94@)8WD!Ow`~LNkWE&?zn1Fd*=Dny})T_1fKm#*@_Kk@& z1%@RI>Wu!|@0*#X(jP%|zgmtW<|FxlUtC-)-Mp2#*;X5YTs)r-k3S#Y&z-;8ap8IOBcxGbE$!CbTiJEPsaI`hQoH?2ME3fQ*#Ewfn zOV91Ct?;5ucUhf#+U1W`pFhtAM)Le1p82J)B&pRiN{RD~*ypu<^zXIOp4c+?ybCfg zMVxN&IQjpeBn*dDRYgUdB0F0@Ihzf$mW=)iU^genWSlPXO@KeaKavhPhDH3SMo?zCWm8WH}dW;$P-AXbY5eZ-uUm@{M9>`4)T>%Vdea(EP*%uI z3!lJ#-B}5!0_2Ae)q9HHoXfkP+wz0cIgJ3#FY(F@BK=|Nk4l%7Tjpf)m|H#kvJ~ht zMu-HK&7WC>a0wcSg}jLr1<&$-(+(TR=}|d5*!mi!e9qa(G4dOcZsS&oNOCj7u^}5y zORJrwS%B3q#)3X%#|e)AzCSVw3iggy+8Y3J+~JiP_CCwP)iK({H{z*M0=_*2835qW zQ9S>^XFRoL%XuWV%FppIFdoAbZy+vUT1X6*-niE!L^S?I^;m6Npcp_riJb-y%8G7BgtGd|XoYtM@ z*|Et>oJO_BPg?9P{}o)ZTA9!=MI=8FFhp70H}-k-x!5l)lU2!Sqju+!nHGcd zpX(RY);f))_>xmlNRU!6=H1!!FTN&w1RrFDxIW`X2Ah@-4Xf8+kKv*c~1ony{nWXszt-Mwgc}WZWovNxm=Z6pNMH$n@Wao76r}INIj0yLmyQgn9Pc zw?wU3tbKSAqXf4_tgtpnS57rptNP5zCM**RiRHCBC<1}vtRs>CJNOc(VP`90PWEb7 zRZ0d$1PGbYPa5aHtTc)&u6QGDQXGQ#5zETUMa>F!pfM13-eTcEOQ{w$HR(~7=(orP zHXKu1s>*?JeU@)|y~R7Fq@sv$u6y7YSdSip^$kWO7=ZNR0M7t@2gd)e)pn8fTm=y& za!}j_l$XH^1&?|c)+g-d=EnQs#&Z1xsTItA1IC(_McLC!owfq~4Z84cYOS+b`z(nY zBJTZC69^SBOVR(<6FQxJ0|QBlf||v`;r!#AoW&(2BF@g(rF80}{B*^Ok3fu91H2RJ zmY_`O_t+}UJ7-IY@!HNAR@{Udog}NKK|4NNDw3&bMn|u#uFtsnl<@^9nHz{fr0ChB zZ`VQy#c9M+G=bn)cjrMNScHW1h_NGVblKkV=Y#tocStR_0fd}X8x0iWN@Pv@XydAr zJP&h|T`gKdpr*B;R5m%P%7>0V0p+0?jcUC4vpl^bDeV{KTOlb5GiY6sx*Mc!eKOr8 zY@f3&O@KcbeRnhY#|f{4SALdFc1`_liN}PC)_KtnFsY3W@?}zYYs+G>%iWLTsLIk|iu?j80<3thMt3pTd=ln{dgd16JwdIX z#d@hI9Nd)z@Dm28NybS(RSC@UuT39~gAsfq*Bw2?@(7l)Teeah} zsE~5MPKx@kXu~hPG&3_Z7sx|yTE#1h@YUh zEiNk~1e;+_PhDAmsoz;@AW?rSgUa|L&m$=_^}saEJn5VV2eN9DV_}X=fFuqG^wn}j zS9s2gn#)C>4UJ8V=W0*`d|O*@DJS3rEhD)+?~dB$!AtXNui3{Yr88z@vm44!K5|^8 z7T(TF!x8UQR;J|Ekr|D^upDWqA;Y-(E zWrf)5T5i;zv$U1*ih7x{X<*l=jkwGYRM!tW9ep!*l9HJA4$}NTIIbDbUH0GaQ6Y8= zX?>?4xIbobDLsTh1z2OWvMN7V#@fcrEi}-opqY-oOqt?b&vwVC=H8>&=?*ChQhN@mpZ9wQFj`E-wgNwzGn`JUQj zc4;9|!96Q!2ukPT)kU@AB-7974+uz#0SnI@@JnkO?a??n>>l2~|NKVA0}0vQ-hO)m zCt{!$lE06lwTy57tF*8r@ycn19-f8zsFnZv`-1(l zY1j8I{gDxQ#sZ?0kz5cvJK5LU*p8}Kwc1Zdreu5AQBlb=UU;Sl&(z}GCFY~*lK<$k z-M-U;gpQsb#KvfxapNWZ62#a7Ex7W>@;Ze=AhrJDm!z1ze37qfm|f+Et)cH&|gW9pZd*=(vv%vTrdc{nhXL}@f$ zzZSlFffq)+$^c^L_Q@z<&e3gj-^S&# zQ@ha)t3~wMd19uBxkZBcFj+znFQ|6eos$v-mZ*}b=$*~2tsARu2jjObDu2Yu!BPAc z&3SA>Bd__%1M2!0P>tb}k{a8K*FV2EpixI5Jr)!ol`y3%u@@G*qF7(HE-|#ig1xIQ z!GIY*R?%-FP*eXwe);^mdp!Ay=iTmI(PcIzHI=B+=_V;hPcd;ZN5Y!RNu#C}=Dl9S z%R5PBmoEz&5n*{x1b>q1JhzY0H_Xj3xPlSK=%}rAFQAXl+J3KIy9ITOw8VSXtfV=w z8`On32CDfazw_40?>VkW?Be!8|*ud!Mkkd1yApG zlf8e3l~|XcaPCF66jmTy1Sm}ORoMjPSh5qfpf#9D8;$~%Z%#tGR z910|dL5inVYFrZ~{dtwG7$1~_t1Y^5Ud%SB+td%c6~cYKqy+! 
zTXFI5zWg#6R_|M)O2E8T!kf+UuA5c80P^{Eg{7EfrJlUds45S74#-&q9n z4@`L@Z>ZNeBv_xfj!kgqve6TCqiv&C_au2e=FwMq%pO!GBjjBy=4J9!q5JN_r&UH{ zZ|jSR{;B2Ukjn$>F%&NEGFwo;!(@CF{k$9Jj9>`-8*bd=t_FcEh_M1Kwzn>A=3Lk; zQVT}%{2u1Sj5XMd8~jBKi&}0iXRI3S;y6S2j@T3%3Eama4p-(xiZtR9&bxZP|80wd(F=^OBTvMcVV0KV;s!mp`%ZnL7puET z{b7CT_j_!m&Pi-JvNmdA?gP6wA^+ zv7fQ9m({GnfsyoFuH8A=H4ji=)+5mf`Mx1+O)S+t?zOn!>Zs$`4NwJc8LS?9}0HP3-#G>rwSs6ATsSc|9|6 z6 z!A=>L+M`cVQ?eYKB{JVhITLB(lXDU>R};J|BkG-B)rzhLmA+Djk39OOfXKPIDI-Rq zgS_<40?II@mDVv}E7*)_`ASD#hq$OU2NW3gE?S%2@#kQIw+`6BcCYcK@n|jSu~|N=>U-$VZMW&#F2vQ2knl&KCD9aiZLFH`oFZ?A}CK2LQ#!~+@A_~kRsi7()r|>=M2__DnE-)zXs7m1xaNy3Ic-h z=R)a!kdhnA+D`6i=)*EV) zYsyv9UP*0d|7rK{%wv-{MGNaeZYjoAh=dl*L=kl2W|4*S791icyQO}2iFCGaVP44$ zakKfuusoh+2@)Hy7j<~SWPV_22*hcncG-%Eh};Bg$}FSaEP0JzXbTXe)5T)9A&qu6 zZ!@Chkw)sbH+de*9`zW$J2)^;`3JXLq)c8RZ=|Sta(V!_l#brgE|u zd$sMI)RrigY*YgKa#Q6GCA{&W>jzG-x?%=cH<-7Y5}JX8T2PmBMqU0vT{GYK@4DRmk?yd>$gS)$17~Gu@+!-vmL+}70xwH3I_g0ApRv<@|DmBzm{O!FeiAuUHp2Y5c!g`9*4amtzH>ATl_}LBdLl51AWiQQBJT zpSBpXXe)i>S5m1UL{?6O%*uk*YQt~r%xZ4qg68DJv!S{bxllmN%cn3OVrGy1jeZk> z#ZCn)eTv>SA1l)+`fVa?P~h1?+RKP34d^z4YNeHvM{M!@Oj*G0hL_%0cm7`5O5|eW zAULoKgG0X*`3|eRbrzF|j6AYhL2+o0{2QXOKyJA1%1R$4lDzzUx$%;uxABS!+W#3V zp7JwXA@o=%#LW@!a$;m?s67`WOib*gp?j8WQ+2}G`7OLXiN;v$#Om*#?gKm~uhZGl zKj+I_tK)FbzH6MTkduhoxXJjz;4Ffpreki!9e(#pL`(5ct6da6U)@SwzwYMiy# zPS+Jfvu|J-imQBXY3Yd0tb}82yy9|Fp+q(^Y8MFwlQFO$=JPj<=J7eTSd@RBTiB~O zmo{7#CIXukl4dTv*2yXq2Nu^p`2ZkNSLWnutZFc z;WX5}UT~r%*)u0iR@LG|qA0`L8(~hG#lnYb7-<>ksJMjC>^y%gO3Naj`{l|3tek?( zcotH5MsOKb)+x|tuWRCWqBdRmIRykFHbFb`O!fF5^2Oe0X!cD^lJR@G%3iz?e{ljK z)v#e7p!1woXQC_)gYa=Jv7feXa8@uOXDYm5mwmH{Q;c$r6~)f>&eB>TLUWO~liwWp zIk|EmNxDiKB0-iXGjKVdA?;8-LPHag_gvA}j{$5`FXVU0b;l zw@4v``CB8QZ>auF`arw{e2df=5#uX1^M+=O4~42<`h75^c6I76IB#ym1nX4@79ZGt z`sXSW(ji_O!aYMSe}h>|-b=CxAg$g*;1SsLCsd2`F2o?1WsS9^MBB{QMA>en8f-c` zgIJX3;kniN)~RKpwFw-q%^`I{;cSXnIUl1o@Vldx7t1}2pXG8Mp1#6Mt4^Lln=cxE zz}(k}D7cx#BrYE(%DD`_5(9*Z8&LVEJs(UGAB6ayk%&ifIB!2w!Z8%_^6)%%7R=T7 zXAzDU8c;)vQ>+MvvUilRGo{Q509$@55~o+oiG!}aMv>XSB6*Dbu7YB7!9(2ni6lp&vy z!@;onp2!6G*GI)o@{i%^{sIHU*#2fw-<%06^YG&x20J;i^GPH`Xy z&81%Y;%@dAh%0?IY_Tsa7pclM8K6No{wE&9Ta6pPil5ce0{S`%>i#9R^EVv=0PDG1 zztA6B!Lc}r2zD=B#;(q)s7!(!*xfx#4$~n&=W?md^fOID=dHmZD$3Q{NttOR? 
z^z)2{rYF$b^vEK=NJkRx(bH8^^Yo&s(CE+G4|wwaM`^g=UC4>vo8an$aF~qzS-S$( zyrqpi7=N8khc{)CXr3AXOOH4;lDdUAXf(| zC#TCALQv{=1|cB>RO(6_ZHW*1r{cp|ROzjm&i3|N$Oy{Iw_aYx`RRx9jJd4^d`vA$ zivTUbW;IVrYHC>GjI#9-tYk1WHMD)IEB^L%qIrz~k?ovU%lKzBog_AU-Sn4~7j` zbb1MFto=cR`qQ4#xe?M6S@;(B{N(bu)w<&aC_`|MI_3amzxt?y4)+x=$=(%$74u5S z1gD|1;>c?0RXM{B6Qu-~WkgKts&v#)E+7GlJ%w})jjZMl7I0o(N*lc#VP+O0!XTl) zC-@**=cw*DD0;QEh}LuXv3B12HLg@&Odn6}rO1;v#iKX&F)KCYp-6_A{niBnA;dFrpPe>Xwq3nLK6AR#cq@oB;f@2s%3dtfG;2efVB##@wU)wu90eJ)GeJLe6BCoO ziwkLyTx{ItXe)A|Sr;;S5kZ_ePo4$R+Xl&1CE`h)@#_gvziuM?H`5B8mH5z#`CHzV zGRD83@I?$Pzm*cP7WTCSPclW)RZ;UXAP5;4Oerc#djBMEpyEl$gi}VH`6=x7In9YS zfN#B&_knh;3p^yhM`()=%o*JO1Z~BWl9d-Xv;18ii1hU(Z__>B5}k)8L^*c=>93Ga z+TScxex0eL?Wk^F{2XU%c$Vw{{Wg(L``S(0q(?TrS3Hro@o%ElKc-q@bzW&$fzATDP5PhI2%HE-{|eF*<&93S=@v-wD$`b1}s2uc|l$ahIjhIdy8W$7WgB%&xfY#&FTr%Qf^@aBcV z=kcx|w=zTIF`e!@FAq(_h22bq!A!58){f{rJU9)anaOoMp$9Rc4zfm8RDhSJF;+hY z!CAxP6ds!zrb zl0!S;(4{;hlP3pKN*>Hv6(QB}x>1EGa&uQ^ZgY;cxreX)!LHNWHtgsVI`MtN<>mjG zok6Y(5ASS`KS%ME+d8qV=<11#Fw>1<_$%*RVG8sm10 zi&X_T)lw0#OiR2W8mcTG%4mkCTLO)1uPm8zJ9~i-B_s-HWgX1dsMZchN+e1-o2ug{ zKSg=lQ{Rp91{1TM*;lG5Zg+^LLRc)3>le0}`)w~$$n|jYw3i+uzz@|ozpd^6_P))p zo;Io&7%eX9#AT%CbB)S(@GKM*Qccy~NE_AjHo}7q7xcBF|0L4V+8iY`-WzFO1pGTD zHtv3x%!}cF@+C}xrzXv4foI`{iDTGaUExpxb#^8w+{uakhs<+ytmEvL|M4G%_Tuwo zK2dB(?(Nx`x)8;My@QLijt6Ge@<>z^k?~2Rtn|3sT|J7O$8$r0*Qxu>A2{!JzS25c zynO_Gi|BH?85J~(ePNQAGL}Wo*r1+A@w(Z5?FuIcIX?sH=jI}G9K#mAqL|UR0Lw7- z?c&VA<)y^83sllyps(FgaL!^U2NI4bp$|vhyN4HHz*`0H_Vp544(K4e?%}M&PgN7{ z#^i}*^a^H89oy5B<%WL;YRCsqZXaWKDGYMK!t-T}mJIkkxkD>sus24d>iVm!_?D}o zdUR$?5eh7F$Ak!7bMf;eml{MJt@FE1bBZKo&dq&YtvesJKJ0`CJ z+PKEIN%2*+TEWejCOaQ{9D%^rN<0^)p#3XOxV1I@7C$i1I{uW_CZ&>0y6_4k^Pz%1 z3HieW!@(4$fale%omab@(;QEEmd)JaCb7HzBXw_w>ho#HgR85XjEe?8_p)Y6o(@pf zWvhRhOYw);rw_6O-p}W!5wclwN=69W=*lrWxPgw#@Xz^&i@U4zuVUwZ@6>GT=F?Tr z0vQjQ3e{9@F3V3ZTaFrH;PE3mWQtaFVf%^^V zlhUb2X+4>W4bd+nKD-WoT~3~kuUm8a+6j7&-+%0aBOwjg@kP9*&Z<+VVsO-1#v9U+~ zE74oS*VWXh_jW9-nZ}Ipy0Ml77;($6?nQN_xq91^TslbvQY>um!f)=j&*|-iD5wxK zIX>uR%b|cp>M>hA=nHT~U+P*adA9)%LB|nM?Ts5mvC4-NYdju;t@?fZ`cFZ13r=5I(t%b_ zs3>=JtRy5(`dj*=)7!9Ge&Z#eXvl?sZj`64herhw_jytdA72GiccN4?+Yx7ZSBi#9 z4Is<&ja#f&s;(o=2s!x|uFWzAlDN6Qa!#ZE%o3w_-AVfy7dFi9@n9-zutdi5==UbZpR~9Gp;M`PENgz#F+ASZAzzAi3yVLjEB(9f|zEPgl zGzblm&0f~7`faaVrgI_@NB`J{!SB{h%2@-=5P{h&;mS`$xYt8But)^&m^hZ-+nZ@l z=F8;zVmbz!YwTdcBiX%^yiX0zj@Y`StezeM7;OFeJ+vN_VQa|=9*h^g4R-ZL_3mE4LJ^x>P%@I!bK4Z&3L03i z@=V-0yKpAh(=&7B1GSg8)%U)T#y@OC&H0YsCZZ1R_UO=ILOgOyvXBOX2mqgtG7cz( z+i*t0!NI}uUi-gBkF3nhB$5AmnG({L9!h~7&$cTW(dJPjsJwP?;V5&Z^(JdySfoNc z*1KK3c{8Y<7f8D{4HFn@NNumA9NQW=QcjwcQF2+q;Kj zk4Oz^-6Xxp0DD2uHx#tHAg&!(g?4SmXwKhwJkuyO@;WwMcLeO|I>Wtq{hPH7jkbj8KL>laSZdM z;cXF(u?W{>eus}y4}u&(Nqo0<$GY@EO)Z91$x{WDV5N0+x<>4fW|Mq-Bkw=BFweDm z+NSO6Sr@Tafkn5DGCCW?Ik90No|8s2nep4Rf?^ySjV2zp#%sZs6}Mko)x{Ibg8|4p z9|b?iYmD!I*WPkzlF}#e=?u^?*#)f|6wpYO3BL4Cf|nMoljQ*uCN$O5=h(0+L z2=U?pJ9xv~U}AR8yCkj+LqEn}bzcWiJ(A1ctv|mAzLA~#y=6!y_MiQwPT1dK3^59F z>*E;Q)piTCdTaW7+t23x_iS6MNc-je^XJDckO*|HV9$Byct955SWkdFCFxz-x%oNh z+im4N-i!o{N>*}IoY#L$M|aVQ=-$N~fFZ|jqNy3!`iOXq z^#&o^G=R(C`**X$PPNYw)n8sR!0(p%)V3OrN+AcIDhVaWpz8K0QvROo!kH^YFE7oB{(zyA8t0!OpQEk(Q(DKL zbJRCF)9&YR@c43>G|sn^4qn(aiI|#&yo1qu5}Y*mx_*M!6y1 z@b4tqzFuUYC>Xx2Lt=2xZui5kfvun(kTx%?EWjU~ zd+a@PcmFj#0gW|Lg)>Brx{kPZFEvz65uqnv2`j{2s9-ipTAatWYiM&58NZ4?r$Yq) z*hldd)Lv3bkNuy@G@^(7Y^6gyhk-nYy7>tU4#?lvhu(K$4OLJ;MntSQq>dy>0fB3I z5q(Wpx1}XFW=ZE`kV#^AvrF)GLrHw#W7RG7ln24}8NcNlC~~%AOx#NzCdn-(VUfFe z#;39gIDHNebIfGa=ipK&rQ-pG9N4isC!?XkfJd2Yoh|q1Dg=-g7Uy z>%b87BqTRIxe8Rxt2lR+3Grb^SI~~BHhvnVfuq?3V(WzGVeX5x$hgqMqoRJn> zd_3-CPD1z}YHeMCvIuumYyCFn-l8H 
zTXeC(%?=gFTYI6yGV1Yp*s>P#h)uDvHJX7XBF6v6(;FDrLny8I-+_^P87RktsWK)% zA0rGFskucTo+gK!ok(1hp-rfRSYVr4S-lWkP}No(oQx6*f^$OlQ^2+35{kX&S+Icc z7O}dHY3v{9%Jl$Tg7UHf1492^x)7wDLnC4*b=eO?Ni0RJh!&r?nAia2r2ZBfMsmA& zZP zQ=F?8I}+}B*OSN&t_-Jt$9?VWEUA*lH##~!rX2Nyn}MXiKj#SBaabxPXi(1i5jhf{ z6TUIo*&z_8`wopy;Op|Q<@AW)xXflEnp(g~;h<~;uuklQgbOZp|8(4Qb(9U1=|AV? zh68Y;UBk)?XHczATcD_-zWEClAAEnq2k|=?Fw|#URB+!b+}*4NYG*4RZZWTPg1#SY zg&s^z)Rh%E=^3xm@@#o`mb5k*k9oxVSGTmlu0FAy(bHsknqJS=uvg`9Qosq>rFdI& z&5X@S0CU=pII*XXPWvjLeW+Dkv&StPpBOoqB}~~@qX4fJ^>5tBT8TLuF0w86eV>HIW9ecIbZ&gjvmxkI5@FMLt|qgxXQOc zWUDf8fKS51v#k&9lmMT`xnAVTC*2xSp&puC?xn3^N|`Ko zoRLvdS496;=hg#43f4Bu)&?hPs_RCa0Nc)}N2@eL$5$czevL#Zuwj+}SEr;@XGf@l z%LL`(hmv|uenNq9kJ(~H%i~KeT1Q!BE3u4sk=MP3v2H{a`l>HLwzBC4op&v(ykXv- zlA^M5)!z9M`!Ij^U3u@YTyKeK!W{0H&%X4~$qAs6`{(!l1H?&$mYX{ zZc62w#>M9`)W-+KSi0l8Ldvq}aU*))cRI4E`W%Q@ucFZ(;KN@|eF=-7x8O8TZN8Rg zCAi}Wb-J?iz9r(`RF7S)+dVQjfxMQrrca^6ATmitcIMawSc{|M<#~A?cYHi$l5gk>3?%WHbnKiwywTn3=+eilhgSwFhr^375KVp@m3QsGKH~r2 z5LE6A0vh%2s37_p+$XlBpjvoov#qbwOVkEt`C74x@gi@aHeD=)1<{lSGsmn+s9F1z zF({d+${GjzOrmdmAfHVD`)vPe?OcDRV_uNfbPRsVe4&()l2~2lL+T_Q-xHgqcq-}v z^J&$*s2=-0a)LTt@u{PRz&xM2fwbx40?fFtj360BB%3KcCl60V+@$6elA?@)M(LF( zfnMBtBzn#sC*bh%*DrU7u{Bx;GPyT3>1xe0O*jD*s??H+?7{J3@(o>-Svd~}zC7@IE=jduM?Dr09O!Jk^KpC8I6LBnNTVf*5-n`z`v9KfKghI!7*4CjO zyL1r<>R=%k)%+JS9->JqXTeoRED|YM*WE!XbAQOfST9Q>W-TGY#jjghBhlKS9z`FI zG4B&A1S~6Rk1~D*10rTBz;-UBFl;`VX|2*oiY%9l;9D?&Va~KhHl8ClEYaMV8{<9M zb+RmgyoH-5>h-Zvo>dFDvno~9w059>Wp?$+%F0TAej~atpTO!_8zl#YM_rVeB;!OG za^q?fMEUTlX61*e9v2c|P+&BfEU6 zJ%~L)58|pKI<9;{BjP<_mYIr%>$fe@imItZ$T)WR-_Qv{L`-4H6?$FFkC&b2b8s>lG<>(R9+p8fL+WIX|auI^ron{B)eFIH5I0J<8t>gsts z8)^hU8c(TptGqUnZR2XozZVhmO(F~^X%&O3i&0*Mo4-uODVBG(_irxZq=<04o~lY4 z2hRI=I_KZqs!?AS{jb71&_^6Toz>huUgt3~p3z6z@1^u*jf>Lr) z(g&-Ip8VAGHaXK3A`gTfeC5pz!)->|0{%`DZsnv?QG$Z3j|zYt<5LSs!=*&BMxB5p z+h?S`J;SpUWt#ZEx@F-|UAG@$|L0 zK43-CpTx^_wh{C!E@~%#<3U0c?L(@!RUPG>9vy`tCr@axUZ#Sjr!4~wXwA0YjBP%& z%fH+@!B01TW|-lDJolgNCyP5gX^XeRbAhVvo&|A@)tVZMzW65hm#UH>I*kGr}GU+Zlf(0BYvpX9{^RT(#G?ces|l$oa2-NoqS%T{r{@wPIGD_>BW}5%Q){ zQi(#a-`CiOf_qe5SLkicC$`ZxZVr<(wyAkjH5&GUqtP8O%OJt+`e4Xv#3LvU6zVW) z51Cqn+L(ZouOgUG3OtrH;K!EgYjwDGR?cU?EZ%{fnpS%(ezJSHU;tJfv9QNqRu$dv16!Lr z!07^a=X?hYIRb&MUg6>0Ie6VQ(2yfakR1V5d6%$=<$mk3y}fI8W+@~+UAV>>UeR$V zfIGYfvQABQCaRZ|a?GAX?W@`{exOBnP|FWqtgdCI2)!iTKRZI5*9pQb+KHUzxfARN zVLAzfoQ6#-ub0UZP*Z>Pmo&0TA@4PvlAqqFz+0o_*N0jTkL;)SqN1Em2NEDc6=i3O z2ERugd#L)tW)g71^GHT!$Zj-v#^@*28^6H%c@CGKnD1{{%JMF_@zmZ{XNbm@xIZ9< zP3LnhRz+XN=*}m47Q|y~_5HiE#hcD6Ti$}`UN}#6ht-|89hL~2U;!=a)KXWyl=8QZ zQtr<%SQfh*-InMlkB_+^_dfWo=Xa{V-c%QcA`6?6I8C@ebg7&_l>Hf(gcn(y8^&3B zoT~+g#&^(~RDn4?!3ZOsMmDzOKF3$Aaz8!&T5Y;AGf|!kTnS`%4}Lwh@CbzECXr;_ zs7Q|r1zKtcCaTB_bz!&d6#fLPk8P6kuknBL+}Ie-y~s)6+3v?bR5uhOzNg8&sPOvd zvm{rm%3N68bmNP)M-N9a0kmo3$`4IHFw*1jXWNa22`d%U(vRwxQ}vd=1Q@a0tFjT? z71W=KHlQ(&X_;g)3X>u3g*b6~Rb`9rdW;rqDzzS}Q>L{|kF#(GHX9H@>jg|%sp)kI zs-}u{JGvw>`2yX2{9@|UVXB(NW4owdg>KI#x}$kX^~Eu?#$WJ9IkHj3>ei*keiatl z*Ro2>UNK#5_961vbw7gwI=j7E96JR@%01N$SHrv%T?S%+?4b2bYTJLtg2F<>9)D#? z)f&gT*DQ}{1FT+7@N)QF;J3G}pT6vew{LvmCbmO%S`?r1Z@9clXGGnJg)8#@*x7e> zgCTi+ZE5vND`p@NR4Gna4GZX!h^9F{q_fg9;&D`(qh$Yuxz?N;RlB@x!`0b~AezNG zyhalyugol?Vx?CcWb8)7Uc$VCVT3O{@_iA?>t(#Gqypcd6>6g`kiJT-vTF+lLU1e? 
zV@LC8IVc~62`Twf*4~73dOv<|i%3k8LEIcb<}oSQzv@CCmf;&2+I#7Q9E-KswBI2) z*UE?SLi)V|_?ZM5TM7hoK0YO-J+SBJ8ozcr@mZ?XqIzpOc|&vJkAHNIim(zn9DDz* zbW!#V&Q#dd|2&VR3OPSt1c~S>SEmy1N9Qt&a%3k3e2CX#vYvLU&&a7RG_bhU$R>CB z+Bls${)C<$L~Z`)IFxtCzGnufUSGrM1%fjn8LkiSiA;Qwhh5P}A2|X#7VyGfd3Wr%~9U=}8`{y|g`8h<1-p0B~#dB2Cl<>%ij46-xI2vZ+CW)^034UGg@3)IZ2 z^vY}Ty|AT~bbSjuR2(ZD!Divaqi?<%*@5>?Zn7aJFnP7W!t85V`kR0x9rkZ5`Jk7vEou~9>N&GBneng z%}O|B%u3kU+h?+|*H_)f09^~*#9{W5lsI?B;n?R;+*$jvRi;B+oN&IWvEzi){g{#2 zAT*3*5MfczIZw6<(j8s)(p4kn^n2#M62(qDDnJIs{#pG~wX|(yVu61EgJF4zy}6x# z_HWzP>UoJwkQA4`IOy_X`C$o@+?#MKJJ2Ev)2*nw&iIlWuzTm zX*B-md1<@&lAc)hC09=>gPR~V%(7`d@u&L;T($?SNJShnqQ!MELITzU`=l#7>&kJs6 z_;!Lqz29*37rUf<>V%RD1+H8e?7Tr@#K~%9c*yCe=X3OA)t^ORq>XD+kQqTd;yXxPQB9p_eyMgy`@o5aYZMr00H#M@IH_&ObU9I5~=aaQ1k!=Y2eYHH89TQe# zg22hHb{O!%S0PhFiKS%(psLjRrAam7PlZUJP@-N{()ETU&Q9DSrz-DOS(r$RPgb;t zC}poNNXzBG*!vDW8>Ntf_{wmdj=M7{IIJ2o`MYl;VntubUz8j+4$c4Rnf$&iGbI)l zFL&>Igl%&N`Vj@%3hXTht z*%)U_rc_sCxumJZ{rl<>Q3xcMwIs0gB0GlR^qV5s6@a!UkmZG{}CjsB42PVo0){SWpo=%89&(*g*_Yim3tu^G`4qYk+ zdSuVqPkp~IW5=wAHJq`9O=hBz0dM_Cc=ue0k?w#5_Xp+3Rc^1eGydh0Qz(Uu~j zb;oyY-Ij6{_Y-(2QslI9|G|DDN>#5t!o2rK#K@FTtfq_-{k4%-`<0oG!jD`9&?^<9 zNZAM%>!me z>0|FRy}@UVcw&SVEsjOQH1teOWUMEgWj7nw4_(AKK-kL2mOpo+g=!~HNdMtb^_=%F z|KP|U<*A{a6m1-iDcO|K@~p7Q1*c{1kx^!F$qeDfO_9u%$Q?Ho0q{B}F zafx6S5eJS-uU1U1aTWi;3&F53;4DpzTJrt{X;CCFK-lHtLHvxUNzC3ZX3<3Bt92$c zwT0mCK6iLWdAi`p=$LShEXm4g34_B=^y{;s=z?#%zDZ7mH3<_sOCOf^?)9A*x8%4q zQK(&(SidWh<3}*DiJ`;Pr+>C?^z{-h3;j7MOR)ZT#-dxCNktQ#_=r5EUF63RV@!Aa zCeI#0%M2%zoOSnl=DkDpv3%TOb$pGpJ*^^@2!Mv-z=eh?{!kq|5hR!9#Ym!$K0M=qk=NFtNWBEbO!`Wd0Q7>zJf*-|~ z;k9?$qe){Kd3naN>E6X1{Rj{t-I%DO5hDOFnzGt^Lw*Z8+hxA4d}v`srHPBo?>ho$ zrY?BXqLmdigDsVg1}&*RYSU5iYfcQ5Q((ZP3PR?n`f=!labyvhlwUqI{@;gB#gunu zrCFgwHSwBWT)O_9jO;eSkx0$TV=nA6l8>W7I)No?1^?WRUUYte9o=$-uU-68Q5S*{ z)9ZS5^sU-dcT3>1{ma#uz7;lsjXZhMv_PKMTHZsXW`DQV3t1ug^soMC86B||%L@ZM zO}OC5xI^_tKFRec+f=#>XEmD3|u~B8UHSkL7V%6mUB!i-YU> zd&~TV4^`TI{tGf9q=eIgvUp133q1UrbH1#R2KJZ*@5^MFjrM4nWx@8O#rRbOE*#*i1xmpM!Jh$VoGHuva?cp`n+%#`KA*FOiy`Cr4 zILtrcsxvO|7A5AdFYJ4fVYXjjQKi*OP*e|zcm4JNuJPEOMAsSttla3ca0l-9Y>1m- z2~!Q4kA0?5$FRE9ul5K%$ES=~QIv^mmt&URsyWc)$ z2d*026{a129{Xa6Eb-F2yY6ek&f9&#uWrNqUik3D z7_cylyFqRg(c=oDjZBh`u$#}L!-r8GL2mbKkuhHx(j;n?$J@bngr?k{{AjnQtpT1W zYKj0%z;(h8!10cq{j^%7w0uxNQ^=CjPR^n`Lel4XKWZw)i8rGhZr2xv zeQ2y_q(KTo$57!$R(pxkBS`}^jmuJU^?(;u=(c|@YMNwaLXmMjxjADg=LRG8ZD`&bPsw3n`fIZ|D`QY8 zk)ucE|8@q3G?ks zyVYWOGXNHgX}uxj$UK#jS?IDNL;!4NC8y$EeBZn$ zab755@6eYtaY;AzLpfZL`Kvr*Bx5`3YP*038x?A>4Qkp0vc{K7{5&JW1!Gg-w-MZ$ zF-!R|YI&N&MR|+Mc|!G2R#e^<10bV#)Nrz_%sgU8Fg1*WnI3gC(e$x;g&nmPRW!9Q zQSP`+4J6$_>4A-R1&@Y-4;{JotB@nLr_r6qYPgsga_B9l?vn`UNLqp&Pm)rdlCK)} zc-5urr3GQ<>8XZo8eq%|)Zk0T;OJdkeuq`N-IG)>gdWr6qAlaa<=>TC zKb$-Z*@zI(y1aSRA@#niHZSRiAr+ZaFUCZJQ^LWIsPhx%AhgHY84r%7^<>OYWVC_p zox;j4iU@uI8$H{T5~C2Mo*BGva5el~oHedF(u|D)FdlsK=K=i}4PIHx8h zT3MB<{vmXzPf2M1}!6YEEKiOBYuIa+@`aVn1ItxDaDf|&V`yCzk z$59F*9HtU=YN}p*z5vBfNjnB%BS&?ElnBAfL@0S`4gS2{;P5b_L^OTEz;LCq*6%v@fp3J#fq}oAf5Pa3I=%ugA?5;S@!l}~cwef@6;~tWY+9J=A64W2 zLHtU$;jHO}JQ!?A$5b2gDcTVgR>WDbdii zm|Qd(=V&#%32ToFx(V8#FYNZG#U;WD(}~5F4z$Z|^~4k<<@GpZd@E|T>RbqUhzHwS z{zThLlWrpR<ey=d4;-n6cnSrFyPi3I z4waf3tEwPbXBuRMBfB8zxM8=sm_74j*73cJqKV)S6B2?JBbX=H6hN}FDu zl6EU!l-M8uuV!F53^<4Gt*uwMu^S1^MGMD9XmHoj-+g+>l}2;`4GSI{Za5u9wqzTw%hMt-%zF_f|J~-vsq$ zH+$ZLF}t9PZA0I=?0os;S>Il4pLbvr?m;AP#%^~^buvZii-kbX+i9!I_FeMSM*=M( z4BU@q#z9#4vi@w7h}I^}v91u5x()3koz;@cHWalY(4nabp6e04By(j1D)toK`eJ{> z=koQdyj6c6EhO>bqf1U{KAD+(svlE zUsWiNPCA40_NM;M&AJ^2d{6mt_9=;QX{49XsM5Qk=-HifVMyhBu8_in^KCAJ*V5B_ 
z^DhNCFU?PcHGjqG(_MT%;URg((rU))+%aAdht2xv{_AB$mHt&3;6us=r!dn=iTxe3 zFfBR$g4x^GyR5t%X;nvNN)2ccirK1@T?Xn+`yVY4zamKY>KvJ0I zW4YJ0_=?Wry>b2%ak4ltB@1)>Q}A3TCX8IDoV3P@eE8dgLJR~Lsg>OMzBu8l|BTf4HBUq2@eshM-~I| zA1a;3{?9v`So!YgpxYA*Z+C{LvrX@wy(*&7!h7D6 z1wTCm5lE4KJbA=N|kxWKaFr(!?JV5mY?J zlhrB9HWrkSH7Y;8?^nF_^7L}freyF_T~btwJ0E?1<}slsXv8SkB`J%t%CV8tcKZB+ zs%k2)LlP_EamI0vOCp3h`eP>w@6IiYU7Ih_VN>Sp(cG1YFVO@gB3$(89$={Jt4|jw zWmJ%$L;475+DLACA{&!tPCL->xljKOPT+GHE<0bec34iCoIcIzu|oboVM2LX)$Xxi zZ&YQ=*)w==#>Bei=%iND3_F=`DyfvBSZEw}<{J;H4nIc*43c)^qt1P3 z{5&(0hC{9l1@3x!p*@(nBB9B($1yku90oTee{sY%V;`_d5;E!!R6|hKm%aN-bIK)2 z@%>Ou+K8}>)J1#Z%{$M!?iRl;goh;~UK^&|^qi8VvOwfZx2WZ|A|5XD+Ges>Bx~z* ziL%&<_BjWAV^dD9Whvk<9f7dtc-~*%1@ViL=f0zDpkw6R{Y`8;4zhrp{3_CCRp&u_ z31?Ut(KJbNt*{s2c=cf-dNIxvb!xkk8fl`WKY6}=`DWu;rjw2Rz9I=%qif$`SgMR3 zN-BH|>BFdxi38!+&39i6gI@g1O^YS`Z1|~c2z6HHad$BdjRmQ zkrGk*xU0{Wq064os9I!x$}GP(OyXu3JwJVi%oLz2 zes@Ql6rOLlW1U(|5h+d`QE~?Dk&RH_%;P8eIh6?KKx3Fbb(A(~)MMGC{zqUW%)I_x zJ%ydO`2WM!TZTpTNB^S0z|dVo3=fJn#CNSA`r4T5y{&>fOWhXM+c(%q;u(jC&> zedqT-=RD87_c<@;&BWfb*IJ)iUstZ8TBh9Hl8HfWy!R60-iEVYZ>wL9M{{9A(4@)E z?DH%LH~A02UC@rrehLu6@-;$tY);Ph&byNa*Oz1knmGsc9gtq z?Molt_pvdh=QmWQPx_Yhb~&WM?Oi6gouIPYL14QOo?I1!y!FTauLXC3kgtOERTY!c-~Y-q>|2iO5C0}i zW+swlN|g3IQ!u`@(kH;v7#RYE`HOfN~^dA@RLFmo#;rxHi|#kQvj+!I@t{@|XiV6-3> z{Q5;i3I1(V3g_&#^8UC+D1wLCPgvA_b+X8{SMe7EIcM^_YYOV^-2%$@3)7aYHO%N5 zL&@)aA*5d|@}jIt)O4B16)L;iZx`9!z2! z_E=5sKZpG|Sxyp%+ot(CVi}vgBY%(Ww?zlm_Bk&O9z$jIXIPzi^9F)hjo&lVsqN_( zchuV7n=t-Y(HD#k7uHRyyhkVt0iZ$O1?6;^!*ML4|ynq1Z#XKW$93?=qb0-Pyv;Zh3pzkXG93!P>8jn8qfZF50`s-s_5@{cAR98hJ^{hrPyAvo%CYU2h(-jjk z#pBO_nd)?ZMHpb3FE>T_!6NT+w=4;8aMKymCp4^-l7y;C& zi+)rtnpWt$1911n6$@^Gh6O8hpTL?cQuu@_)Gsox+9_MX$%8|MreBAw(9Z#U74U4U z|NQNLgt5Rt$^THm7ZoDl-8FXy3DEWM!-kIwnV7`as3--Wj(%42enADl-^7$W!AkUs zsy;$r>%lh%#*rceC_+NV5L(kyD`Ju_+Jj8ma=qjmbItoxtBU7GPsC4SYPlhlt^V%B z-qP3S9&vZw@S&zgp{CFve-ZP54Nk=tp4kfn{ihQ$p`yaZQfv}FZO8cT9nbV!1frK; zshFequXt}GJTG-2CyyBWyo3fR;8O{KE$L zh(cb5hfH=*C;#Ge9KmdTaJjdMpT6%Bja z_GlQwaKxxIkDNcUO`k4$lV(ZZb4gd{AzH6eIMGRW(z7$Fg4kvFf^E~j zdk>sMN!^qljdbxEIa5&5E%W)(yyC-A`a&X?)*<^TL*^!M(Cn*+DnqKJB zmc+9_eXh*`WYzg@85@s`YAJ3P;7cg!aWlfVTj=hSc+No2S|Cnhx7DV<^0+snQM__g zW>UbzO2Ch%TKN6MvF8tps3$KkD9SV~Do)v5eedM*W=wIV;K{hB_Oqefq}?k9LZnV) z>gpZLqJF$-nvVkk?bMBG+;1vH7JK|+>JCI2@*PFUby1= zXrV`pU8jfDzx{2OcPj_{_zOy8;0uAwINZqWM*F+BSlp9+<~P@u<5lL@4Cm@PEkewFCsQ^xHiE-y8zmDoE$c%Jjd>MP5A2PuZV$5L18To3)bhv z|ER~mrd$0LEfq%GW&<7Nb09GX1N%A)OHRr8$N5U(w+qFgr5R*diT(&b~CAYlolsoh=P| zJsc#3C%2LXSYQORnI@Hj*2Add_I9%_+pg}dY%YoUccz3$MRU~EdHCeCG1kw%oxWEO z4dork2%090L>SJlc)?$slg*?W&YvE9j?8TMSkZGm-_h|vj+_BFkTbwBK43h%1a9g1 zl5S+@lsTWSDn>EwQxp`lQYrXjiG7v#1BcAQm zRGmF_#mVc0qt~gV{TaCooD3gD>PH#GLOb>gnY{tLpq?KME8$epR zoFl!VhSlum!~;2Uu}01W`~JAU0LqC6>;94>hm+X{iqPykO3mW{kA7^y!~az8ul&vO$pjMiTsnW8PKgw{zwL-q&^_+fvAG^Wy6Gu*`$}d5`tMO*BwJ9)-u_1e*#c1eyqrzyA8FNO9eGe@8iWb#;|1dTJZ!Jj4gQVu)6f84sm82$B#VuyvNi>lGxnbYJm~ z{cMTxJwuZ`vhW^aIxk_Z-u$~0J_06HpeCg+4Aki@Jeej^4Wi8- zQc+}3A$&p*i-a>s?X~e>E+g(P_e#{GHN zyK|(EEi6gSSaxQNC3{=+d?)a`cPh9i&bnq$I91ZvYCK&12Lo76B^jaK6+1HIq5sO{ zv@H*b%Yccq$SJI&|776W8>x1@^?9a)+}(}4m7wh^Q}vP+zJn_0zM(R%ILIkyL}JHe zP)H`4K*!sGH4*#c<#tkBaFeDM8L$a0TAsIng$>&I)1mF52l1uddh;7wHv!?d{kF`> zMjZjI=UJ)z80}nbrms7?i#atLvqF7e3h`(7=IUcWG--%bE(oUzB`;?-m_ z>9$U9e9tJDHn>WB3bwl7`B9<>FVmziTLZgWqg6I_v}6;2fTP`TP7{LB@ig^)p~cJu z&=-)l3pxj8LzN#1UG5}D&cS?J8)>1%C@!WXadHbq!Va|S+E^_Orrw>knH$=zPEE& z>7+OHj2Hj4(PUN5_x+=_lkap+PSyz z(q-LV#*lxE;Uw}JLhXb=S6pfP7dEE(CkC{kiRx2qkrbhNxc(luDd^WeDhtm+^Vt)XpW!50cy8?$TJZ)F4> zgT^*u?~VNwBVlou`2w%wG_`nqBbzlp=ixOaCC<)!jpfR@t!kIXEk9U7^+VMm&@OMI?pA59FV4Z*WoTq 
z^8*cB_1`N|+(wz^cbBZ70<_w$Oy3>Ls-dAOlYQaPX0O`xi3$&rFwfK&Gp+z(L$t?5DtMeV909?j~SR#!`qc2dVMd}d5t*E~#eKmS_sj9BT zTj7p?c34s&>c&)L`%HD)uw=bWZV^W^fhbm&N_qcpszZ@>jE`^Y`n4?SmHj@&9Nliz zwuh5~qWsa6Q?C<( zm`qG!B@G_UpfA|u>35G@9MYR%kQB3V?qC9Qh>Db|N^cjVUXgA9+4~8T*_kk2C4`l~ z!!vm5$~Fr0D9}Nk*NuJHLT%ZrW2?nro~2;QvXvJ;j07sFi9(_n%o-V|my4RKNV(qd zxiQF-nosRthXF(_<^H#oXgji1o&BrzkLblH z6%Q=cVa<*~Zx!$`J%fEPqQIXgw5eh>R!!xD@r=R&Ye`mH;Y zn9+~8!EU#gHksGpAO}3UN7VfXw<${%`qlG%wty&~_>J|4%ODy5G$JSDW$rp_V@7&U z?$bLk@ix1=3IH^ugf40$q3}UP|1E&3ErGLt*P@0NQ5`@?cu<-mJar;t%v)@~9ngyW z2~1obab(aykW5cE5rtAWyBY@-ZeNldGeb)u9bw5In?hhV0x`(;$shkZLSVbRZ3>fS zXE+w|3p`Y^*M<>EDPnRhwc#8<)DUA%PjAyygZ8NM?jWf)Tm70_LasRmpj&8e&u4xe zW4q0`>8;n{B(1eR8j0|XtJLIR=*htwkG}P8y4`y4JOZ!kOGOd_62~VeLr1nJVR^;g zkEBQvEbWSTs-#brfkO84Vz5g^O03)_1-~mDZIh|KPw1W!KZ{WquFXHkrXzt zrcz3wJ_@QlM4a&^I`M1JjiTYjgwrL&ks(rhI~qNz{Ss9HO}I4A-!#A60Afmjl|R-} zYgYIus=Wx)oNqhL0|J_`fa3%Ok{{!S&?k)sCcix5dKk?-@p<+;F-m^kx0n9~;QGjr z{d5EXU^BWp@87Z>WKT}F0VIIGG|cipZf6DIcsA1F$+2=B>ic>@F!&`^$g^~>$<9-$ zZx^qc_1#NRF12P~g0?25||P@Ydy$&OBwVSfasib`KD#i_o8^L4qsZ=u2I(~B155i$*X-<`QR zDQb|`LGtm&Q&SX97FL4CWiOQ5E@a2q;^%6U_b>xNTO z^*-ywZ3Ov&ls*MN{4>)ZeM>`_PyzbDsZjVa6v0oY+I0|tU~3NZDWjpX^h+5`PSr^q zk9BYwc<#M2C7;+3UNA5`2bI(LQDwy@svkt>C9+L?dX~_ygWi)JQG9JQ%NKC}0|NBG zEw^Ka<*c4DJb$Pt(}Q-j{u@`HtIj6P=s%LKbW{4@p3Nao9}wU#y*50VyC8l&RGgP| zlp9-Ozf5d|RN`x;C6?n2*$-{LC{^rq7Ow?Q>evxBzKCwQW!L<8RHiIH8h2~S<8ml! zYPciQ3`ZUaoKXB=0BUg#yvM9x(WPfr=-2k{Ylv%kf#KZ+|8td_aF#(i{0dp(bdMzt zL_X+Kmr56g05qW_Dq|Aw!_K1$nS*G6uKFe8v?_jl*wBwP<9sb%)#`317(YJU4=Tl( zYPmC?a%Z~S1CWk%D0Y~ONNGTuOW-~7mG&U{BV2J_Y^fw8;2pu zhd}|A%q~X`t6J^Zx2jOmi4$kU6*#v8*(5d%uw)(gG0Q(CIy>KB%qMyb#=WU>@0xfR z7|Sj@mo)5u+1ivAtcf6+i1M$y;O%+3B~r`-XC6KEB_kH=+4-a%3>S2Vjyvvm-?htN zMg-)!1TLZzj+Hwkt47IJ~gp%)_of`+by^mBq3fR=J z3p7#w{-o5e71nQbzcSb+P<3*Y=?qT8>MhYNvUf%a;5AvdqTdY)FBzUmBC0U$$Jerw zQWbRNzz9awa^RXbq!P{Vc@xJIVs5MOQbDk_&%c1UB~j@A<;D5p0ZDRR!YAI#5w% zGgavHCMY2#$&xhj{5k%Ad|)3*VG4aBUv-&nX|P~dt^KR92|%*qQhJgz;&d-Uo&9YQ zr>AUxPI(X&YbcVJ@Ynd?=CyA4J?G9y+~!1>yh0H&&(736-_?$d8|sxrt?t%=?h`*+ zWqKPIiu*6nRe6@hfSCJBls5;Hb-nrg76!9FruP~qU?@$qY(1F>lIRVIj*XRl{~qV; zyt3%)o#2tUf4trAnQ_&(fO_a7yXVu>E7^=L#-BYm@Eqy8_byG9?k4G|_}&gVIB>i`CL>)mj>lmNyzH z@P)Q*!}n{}o$B5E8hGD=exrd zN}y|J)fvc#EXD*XuN?@M+-pycusKFiwAV-`SJ{o+(q!j68Cl1cefEk#TseE^_iWaR zGKJ$Z5GiRc3P!yQG2yZ01#`jhR@Subq8XEor@}Sv4NZ`COpL?KWpR@{_d^l0=>WML zZIpI;l20>!RdBneD?*q2Jh+V&P8b7pbJ;IZZ9vM5=9}XCY+7?OviHkQ)fXt}&TpqQ zv<4}Ws-mp=qoRI0aL0t=HZq&pbLa1VCJkLbCUr@jl>0HUiFi=V^0aoZ7ETqCIOto_ z7sUj+&JFADahtMt_kQ#y6Ksuq!o^g(y#frG-~z_qd1Ga~jo0&XC1NCugp(m|THoL~ zi8Qjo13MVz1`1N#T*-?E=w-Mfk{}U8F^ehLA*J5zH~%G9^q^&pIE0RA^dGg=Co)@K z^iNRSt(-fwvB~%RaKRtl*CCc3`z%Z{BoEFUb33Vcc6usqBmluY57H}5o;znYj2zUi z3MGZ8a8Mx(zbvODvLgC+b8rr6Z3Z@n|cnyf@qfQTIvFdPT~?zMD|c)a>=0wdw- zz(VdrN>tI2r|%O?XD!KeGV4@v@_$-qzVmZp5Q`8C8#2wmtliDIy1F);K&U3P}uwc4U6Ob8g#+3_nfv-WV~<_7lw*#vl44;BL0 zh~0>-%_o0sTu%21Qp{Hdz9GyPW{yD6^AEkeuosb1N50$NyW2xtbb!yTT5EW)`$mb7 z?VvldxKp40lL6V{6S_&jH1Cj~Mt61`(Am6&b^k-cf7z(}FyS>XhCnpj1<`(2oUOQvGTP~$p&?Tlq5;dFp}naMYC ztXs1Zv)xX1Lsjb!By+{u1Ft1aT8%(#@)LIhkiD8j!Kd6RDH;Q7v?Xyj{yax0NP$P0 zSRr5Yf5 z4VlspczFYUE`2nDa89EHW^K=o~*i^A(uQt`ltS;gE{pJTvcdonOt3B7=c>a~Q$7^W-MGZ$z_7C@=~ zIB{2jV*Ygnhk|e%rFF|VyU>H`+8`+;zVC?+duZ$7yn$2U;jr7#BqAR z`(_?-*w`xCy1dx4J&HLNFCpVogQC|rXZu`*w%=v*U`D0j@s4ZL9M#-0vAezJu%&)^ zn*JSAga8YM(2Bzb;fk&ykMi5CU^^>nr~78-mT$CFoTr z&p3EUx7-W!R&FYp80rF5Ypq0D)DzH0FVsC#2lJl%nSA>}5 z2EN$c&@yUl>Gxc3#8u7yphlD&)1ky{J1@@o(2VBmN1|gvT1_clIft;gOQ&C%@gf7bcd<7db<*v;zdLx9* zyg>6M{7ci>KN|f5n?tLXH~5kPJF<%R9}Gta<+$VlB`dxalniG6SW8!O$j959Bn%O< 
zrI%^~0BBv9PH4^WU0zWgtyCy#;cOcadHdr!wk4l)6M+~Au>ehmn6nMvPOFPh_z@`n zuW-K6C3IH>#clA4fm-xJEWe;Y5jb0qoK>P&GOo|lc9BPF_`T-W4=hg!w`v?A#l9yc zNVGBvKANw7IUirHpXkIu$qo2j1$hLB)e<6+r>#cQSiC!x4ezT#hu+f56Po)90PR~v z&7hBK3;>B(K`+R&SSOsPkswFk&mAz1bRI+}d)t$}p|}kqh9Cw7$v*{mH`{oFI|YB{q-~Wr-Ttzg4E6cXV5(AO_~Rzf zep#k0C3NfmZc6N3rTR`_k@zdAfM~selVO!JHR|_oS@{p#jxPoB+ncDkKU|M5>&BO7 zDp}VduT{0Knb>Tu4}yko4GWhP+VmjVUpH`bT*NH?ZEy>3`y0)Ufk5Y)RgxuaY=Ig` zGqkBJr%+N5Twg5)wp#hQ)z%Y5&bu-EJaQ6_Mkeza0yz(P+n_CeM z<8>8^mEr!lV+h$O@Iw34d`Ibu9VF7|rT;}xm>^9fq;3VcWy#8e?1hT%e8Bo^ksj3({ z#D8;GsdcUWdn3$9RoOA1s| zFB2(D1yYu)t>6Q8F|`8G09zZk*7+}o?u9UgJ-Jr3Ja|;TaBQ;X?>_D!J#?36ckFn& z^Mp%Xw*+5JZ~*v1{EVEK01wkHv7ZYfGAmydM;`94JK>mlYwvtwChBaJn6<5i$-<1g zjys+M6Ab0qGzKroqcvk1_pkX{n?e~)Zq``4;Uqu+I}ue{LU;G?{f4>W9$2;e`ye-v zZ8Ny6j)9UBskFwtTG)2lX}%YpYLni@#?y_W+PZ?i|S?gy5)+sWlt?vH%VM5Is!((9#>F&uYONww&H*|5};&)tx ze+ZDErluYnn~tcjf3sM|Ax#_q8-fB*Th<5-o+u;6o03HJ>%2btv0;J^K#rZxJ7Z?+ zE-%+#`-h#&NBwC{u&K2^fEYs22(aV76!8jAIhhBpV&}SF=9iSvEE$cSJ}gnIPLJLU zK>R@8G~i4kX`48@9YQ83RO1cK!Hf$O5uz%@+?2trWDG1NXaE%gT3UaA=3P2W*kOfx zyHc8o3wrWTSoFsbG7n8mZGvDWH8_y#IR0sqcUojhPwAOs&DB<>VAzNSwd0C$B$V;H z5jpXHl9CG)#3MsTjHCwzfuh`bJ{#ypC5+tgEEqU0pgXX_l5krHURWEyVdT}0TL~mY zHDOUyRToqq%GEa+J%}E@J}PQ!qG-ZSrB*@~t;PRSxKIJahBX!p;9KCtCfb81GZ|Vq z$?5HFwiDuCNIJzogT@uz@VHW?!7cPOBs48Iz9%SUrQ7i3g81j6mkG7NzDV^;5mJr? zNZ=Bqp`s^SdAGgqPS6uj00P>6L?|8Pu9&x51pfXlZM=YfAca(33rar{a4U$NWxmiE zBxPi|76b|zI{W*EO~joeN2m=*9TL)?xwRg0bfLFES`&Szeex2}(Yk)sr zN8if^u*s*{V67LoUk$7F|BdU=%V%n0S^e_3P|b94mh;H>Fo+kr?CYQaLKk$+Ddw#W zZ#tgt%RlTr*VWzZ8nmCa(G&fj$?WuTxs_QzP!0Jg|Nk6K|MBYwB8RioM{*H+UST+C zO3ED(Ep4Nxj-C|n+vu!Qu={hs!i0&Zed34+(n4_}yk7SDmIy0V_kRo$65x_%o7 zGD)_&6CA0HFtCYYFjk`(lf1U^{|)38a})Xz;9nAi2EbbyV9E$vgxb;4z%HLDYnfC-98& zil2Y6EGa1vl9w>5vE37f_tTF|bQx0BB#tL%f6GGw(StI7c4rbHV`o=V8#VQ{F&%;{ zDDabzq$a6^W2I~tUJjFZCJ>BXZD@ob93cCj2-s2-p-okQl-TnpUpacSy9(3E(&PVAmSaVn@dxCoe=+&0nO@oH9%DG#!i$*k5ZHP7s8_A-$ zfa%uGrJkaLp00yxJsvuq@AGc9XBZD4bn9c1(JgPQHHV6mIL^jP2J%FrZGd#5Yc8eM*2v%a*G~HlKMwSCePKW1{Mp7O&z=#6OTtg#gT1Ta)*jhoa+h6G>~$Y89e;JTFTV*^V{s(wu7 z%#wN$B%J;gjI0k4V6$t-?}n@Xk3l8o3|VL>uW|&-gBlYBwvMq#Ig&s604frVS9=y| z>&4qKN=RFf8!{YdWk0ymlHqY+qXUw;@3HP98rCqCaHGmCcJ z#S{CWn`FCg&xc1ACWL8z*ME$Dm6t=esSmBLrC;(ukv}}X2*6suO$Rj@iPS8XQGd;w zZWy$h^cWF}FrK0G>4Gy#kpuUTNi6J7jV~-k-Pm>w>VrN?0nn0>93fi#kO_MbWs7SS zdTfd2CS=8Fgz=d@lKzPR83!Kwj;s{2vHx%6Dg{P53A^mHEvnvxbBgn&x_q+rt&VCd zk~$ehLN9OU&YUO44X&-vKNIqg3}8eb&^$e6>5gUOeK}-qT}lH5!lLym*rG@ovMCWU zlx{F#|7^|?u>RL01_T~W4OVfQMEBkJGv)Jx>mR}X?V%Z|r_LGPywgwPON-~vcz74c zsg20UZ^u>U#Vt{IG&BlB97!XNX_|x<{$W$ZU904k0HDRM}r`n+}*IrUGW*dt|tYWb(Em<*Ko8C=>Pq}{QncKkTF>s zB!5n}AST?~?n1%c#F3t1#PJ&1;6~tC+2nL=i@CSU%J%wTMQr<#1i(g`(`@6v0ZT7PzWIEO8L zL7m!EkfkSq6KTZ@af%+mK%L=_KL-Zm@?WYWN-q{b4Ya0tms!tkR|)Mh=hbN<%cG%) zH*{M@Q2sF<0?A|39}PI4kn2mtz7jT;(f1V6j=%Z>nw1otq2FCx>dEbk`ne_lXLQue z(-^;HW9nod6F@GBg9kVSuMY}G*V7DXD!-!qCD+qw;7<$g8yp-I;g6|~CiI^3Af-x{ z!Rp>*vkeG+9XV?E3&}oq@csKSu+?8oT0iH>l>EYk0#R(gq%ZMsjC;AT@#3yx26R6M z)9lL&LtJ^9jVG>o#$m}+VW(az83W~ zy|loL*+jz%UsEdG>8gu>D?R(=MoorR~5ZW0-;cm>Z(o${lJHnd?c^~QvDq;(7T6iKm}(~teY{tjtzmS9I~g-I5E zj2XN2Bp-QX1fm6uo{P8~ledOLM z!s8;&!F0wc@FJNM#P$~#_&JPwwIQgTD8Cfk7QX{y=sm--{43sdcQk{GXgf}aKW?5k z1V6fsAGb3g=Q*4ZI3gBF#AJ?K8)c_;CL{hxp#b00%yjyq;-z}oV5JwqR}_&6zW^MT z2|(t^Pr#%l&HAT9Zj#+JjVhOij?+YyeHP;Yb0PC`c4`KH_EfE*QXk>ioKW7I<4O~a z!)U*3iPKmM94}m*f@0mILILqizRo~mybIUMa=HsvCKz6^gx>YmW{A^GZ^Xv)amjhp zaYr7~?)E=XgFw*EIr4UI8o+)Ob-Ww8r66&2X$Zo&9|&5_Gm*T%rAQctXr-xYbt`#| z_NDxUUzJJF9`5oX-`d4LWFK|gUEMVw8~a%a;prlsPg@&vXQxf|HL;tjmib|- z0{5Np_Pg}Thn4E;!0Xm>*#G+APQG1c%H3~L@PWsc*|bn*eetcrwy~i3r|n8EI?#^4 
ziN@+6ydT*ic>h1^A_mJ8OO9>%mI@{XSl4@h{En#zuSEtdUQ(&FRkrA7x2R_}*f?jEanl9wuH=ipyg~==8P(h^rR@C^bvCpA3Q4P8!=fBa z(((oi;npUz;bTfQ)ADZrtkOY;9IuFYnf{G0;FHEH(2t3_5f{%C*a(h`7y%29i1#4C ztQpFk)`DICiW8T2+w({XzGFLmdo{jpn!i4v@&|jzMDOUXO_o#=ImuFF>D{;W?urK6 z+t3u{Fe@cRF_8TtCI^QP?qCgmfnt%+@YR`hj{KMrtEJ(a|AY zJ@(r=3UV*)yZ;L9<+Mt|30Jaxz{>gC^>cJLY^nH?|D$F8df(pdA!8;oK3X&TP`NN? zmh2K2dpz0+=e2RiJ2`DjPw@~x3qgqf8x+~SrO9#s8lNSEcR{V9!2jlUc^l28*?D!g zb6})K2UO|@dXFi@vqN2C5BZFUK|aKVJM4n5urJ#v+x7vN=_||UAqYAP^^H!tHeTxO z5vz+$#ajF@LyGusGw+43HS#>_9o7Faxz2KNl?cQtd;Ak+-PU)f=^UAkOtqQ!KHdoa z@>@)dB0)jnkqk+D&{CVvj9UjfI`aHarU17KE@-eSvT7h>aeG{4#|X0Z_5F;yM4O+> zROdKxcZ(cRF0Zd^uQwMUh~6gqNa+9y;GT(?VxqI_`QINFG#fd6$OINWX>37U8GI%E z@egfNO9STq9o7-$_fFC%`gk_@!YcnkYRqq&ALoBkTBfP$Y{nNk!?~cSTfEcy^FLC2 z+`J$I3l@!DXvSoP_OAHp0V_>^D>Ojc&MU~pD12+qqt|m8l%FfPKWiAn#$MJ1Xy(;k zX=iy>-Bb}JyT$;E;4*8nT%toLdw;*C_#JPY42H~=;19+n$;g$^{)U(ZF7#*PIJ8v_ z7fGZPPHTyNL0|N(43E`rzxh#8XH5q-3#h`rJ(Y<y0prl%C1mf8%%lcAle* zpW3cLveVPu`QLG_FF2sythEKL3_TSF+5yz7E?-ZqS|x`_8`hNSr%EpLGKhS6m{w)I zN+2-H3JCxDi)3@#8rrZ6c4xUCYTk99HTBlPAIHYK0DU*PT01*j=waMh5n>cG{#7Ut z!{1V6zA%s>)n1kAgdCNEo=ATT6W9IwPQMx{o9=~qoX&-(DG9dWIZ&%-xI+iN?-GzU zwAov2x;nF%u+D_LeI$co@2b4DrFr4P?uWM5K z3mn~WH^A+0HOsH(+u;@v$~CUF7c$$h61ic8X+i|CyU*M5+z2DT5(^-UlF>QKtXHE` zOt!J-B#ZIi{+cKrZ7cp?$BEs}WcoAAk=tueqQ>bZj32*zN@3V|5-88DuG;TT2*j^+ zVs19ecMuTI3w_M4Z@#MAxsE{OR@}#L|Hh(fDmiA&LmGEVcY~Jb{-}gD=6pv%?llnD zqi$9dTBX0~KWrD9+*dnj8orTh{7P66*Xd;8veE5lZaw z1G?!7rzd|FRX1!XV*y@CkiO91QCaJIVRtd~m>h7eYR1wEBCrr+XcOHArit4>969TD zje#m;01<|d0|49(Zy(%EjsgilGb@EG3RnoGkkUGB@mtG*wYBzVG2|k-E=iV!7ewDb z!j)l~#~(+9JgSAS3C!r)B3qdV@!Sb(ft-`4k3XzVlE`UvG&Ixl;ofY7A)C;;*39sp zb{gmMppG%`C1wJc?<)@bQO*pE4-7Bw$uj zR;%1^p*;HM!gjo#Q7|2`&NiLuq?V}Rugb{^^nR@FrRlnJ8LXZytT=nji;0K$?1CL_ zV(H1wx-=^K&vy?~_>S)vNecpn?D<@U@MVx~Wu6b!zSBvnO4x3=u6$koE$xjtw&nO+ zTmSEU<7*;rw|PW@6zX?F4MDr$W}Rc@*ARSwk>$gahESNF zX@2Y#O?IR6i4T|;>4J>O$bwa;#dA#lh)@!72X>o}ex%w0Y9R(>d`tvD@ep*jOE%U;|IrfuVH>qW4JF=~Jn9 zwQItU03QZ}WSVvq*e+odW{+pRH_G=S(cck4vbqFF zN%c$zSvAR)c$1vEe%kc?`@@chjCzC+dr}-%`ZwVGERm*B*0y&8G-8*fenN zUOQ=?J~R-XBFKkE(|ZsAeOupOx3D%GFNAvTiF{i}M<)FTx!8M#TUC2ie2DrNd%rV- z5Aaw{zu^u$eKr&5on*jZho;1gy|l0rEz0BB73vs&nzvW4*%>!^byqIQYwd;@QZN}J%sS|UJc)h? 
zvZ8!|R9T-bdx=zcsMu?fE%YeLSr)GSRhg0xNARXI7UUB{@FRsqcU9gs{PljloyJmb zg|%Q^*;X@&`u&vKQ5#(Xc^J{!F3F`w@LiDXLAt!VWs_O{4c_9dt$4K>@y^C~x7(P3 zr)L|>>($ITLgZJ%?PtQ?Z7lIaU2*9<)OmUh&Y<7#2>qWpqz=zK;Unu2;_2Hu0%4nQ zk|QMSvUYOLa`zn<$?i|sd_Nc~QY1XF;t$zc=*zf8fYXWZfdBsP?(d}l!1B^B$#BHCl}^iQjK=hq<)#kY`+l|~ za%TRhm^i-TI#<{jx4-^mSK!wGf*salMk9s@O{Q>;dK7R9^(@#@X>nsYGM0>o(yi0h z#@CsforJ#t1^O1}oq|BsoY$)9<7V*R{I1xu$wBzmXvQb#u8E%5L0$Fj)IJ7MuW1`v z56W2@iQBu=aQ~Kx$1HO)VSnQPK8t;=ea5vWCemjUmcj?x>59W`5zFm3MDDZc+e`r# zpN0`(btn6I5!t7Y9u(Ied*d;wH)r%jODs??OR>qP!m${naJGyf)n;y8uPh?^vnDY- zOm2~q$_;3r0CUE~rWbWPmps;KUQ-E{gtX89GX7Kg;h;{UCBq(5C9#t64Nrs|i@d`K zxYgukm-6G_GP&Yse%Z%dQi3Jwh}?_Sq3_rFWADsAY|2j1G}2ATgeIx_O#_Pr)|2GF26wm7E0itWKO zSnvZnKuuk}=JeP=F>$Q<)&6!66i@EF#3=}!WFUdOIbJf=*ED~TzjgUyoH0`pUUFDo z1rmZS8YE`3g`kEdeKXfus>Pek9i^6bF17;5=FAW?kOV_mW&S;N}!mrI(NWx6R=_7z4jZ$p+Z0`QvZC;GM%VFKmUa}i7^x{Fresu?v z^=D&}@o~44=<}y4Xo%rITV94nUN%K(MllfMBzf<@oObs-@V+d$GQy7)kWS87=XY8w znm+gQJ$3{qslWb_xTylghP8oTyF)3~QlPjMcZXoXT?!O;cXxLv?oNsZmqJMK$@_gXbIn{c zzaS^iNzPe&@3rscaQQXj>7y?=%iF6AmAg>m`F`_T@izHT>x26=eZ4lclnh_QpDA9- z)>QCUJWCG+N}!wXKK4D1J4yG~SnrB=X1%EE;)`=xlxsKsD7`WwX+ETNmqaoQxYP!} zg2TbdbHWB?7c@+>7UX8MLL|(zm2l==2iGovIBInbGouH`8QeZ^TiiCo-ujFOv^4b5 zs1_*@lP~7x72Mqr{RIxCzLpBf-EU+5-6R&*J_48Is$)pSU{j)es5Ek z)8KfIos#wf5+D~)yrN~3LF-5l8lZ5W z&QRZ>@ptmPKz}?FPXab-Lfw&nxovbdb%mvsz_v+ml9ze+Um9h6uMc{@k}&_h=kk2M zH!%o2&(OI(b%aZif%{kRrFdC#FlEE{C;C-!4142kGY6jSUb_lo&3#{Y)H7Gp0z(M>IFD;H_)`SH}x<@|DpYb)A7 z5R?9s-B9z))CH_OrW5yHjh-F>m$Su&(aXL2u7_red{h7ue^v-Z`v_nAwsyEN+3P=> zyV-(KH4+^+0zIzz93jIBl%iL3U#GLzeTJIf)tq1I)r)M#WUXuBMxfKovJcg?P1b9? zZ^0;3d038#^I-ezOX==Q$1dPH<_ZLhcH zx-W-2L8toX2e;>RZ{j-X!*euCSQ+9%F3V9tfm!@{AP&Wx3+4G9A!f)i%=6zzd<#S- z=A;x*xXbNF17Po=B%OVNO;{Kt8BO~tpzx;mENI<2$fzu2#!PY>G@oAhCUBa+Om5Ps zQQ2};MR0N@iI3bleoe_2oCh;JUkD1(pj?yXromXM1@BVpT20)3%1(hscR3I2-RfFO zmACqi!hwj&?V>UoNvmGA+-8qJShs}2HA|u&_xdT{XDf2ID-1S0h`a+3aHXG+y9C;4 zQ-)_QQWK(4Ys=A;opM_?H#1%!)%UEMuBR&+KZ)PPO`9*8gv7bOi=ZUhr?@OZr=eft z7r^3wshSEHbKfEW?0q~J;LzUT&u3Pixbu$^BD_e})9rRgt1>0$*YiZL zhS<;Y17BZszqvt<>S;jCzJ4|d4MQ@7P)sF;;r=8rqm2Wry z!3)IYJ_(upKylel0v0Z6Ji4C*3rD%riX)%Fj0`>(qNPm#j3?1DHm@v_k3C*vyGHPRBvt6nvz&ivpcAKjTpMz)20Nz~vXS#<@NkQ_p83cyV@gT2`j zz;?#HLtxJh?U)k+s&+aAmc51fJLy^buYr7g|4s~gWmf;)uR_P;&v4xg*DskXK_tn| zuj?FTA|GE~DJXS5NR73Sek|Kx$OYb2^=;Hrun+99^n!BQG|(GN9{9sds58m`S5+2t z04V*xAw@Dsh9@Tc?e$3a!dAOC=>4ov@7MK@M-n#hMR^Yz@L!k78jOp>@gjGYl2t z{t(?vdl!FZ((3$T!F@UxxTt!i$XHPBMfD`-f#S`$>@ij{wi(Wj9Hhixrc`M(VA6=# zzCwrHp3OMb*E>`)gPUB)84^oDjRx`VYs9MB2IrjU+?;A-r{L2d;-l@kRoz@K^J(+j zTDP?otlBn|d>=Z`8Y`bv5B>aGF8`FURleF~8bZy>{c}@Ssq%Wohwqq}DwG;d zvhK8ir0C_5_v8z!ZxOU>)P(G}nwU#Ak0yEK6vx*~<-T2`+6j=8t6QC^#0z3b^f<~( zBjzo$l)Hrb|9e1ae{|9da46u)H0;mvc8b~>$U%_nl)tY_WIr~ASe zoaR8N#Dt*l&XWmydvMQJs9uV%PgRjzIu6$8w>fuJ;Y=zx#%@+?333oI=IyC{!?M*A zy|%Rs6VK}Ph%nqQ^YflueSAz7n77QOE&d-0tUXMFTA=@@K_eSqyZWUF2ng(-pH}AS z?l|WEx8DbWaF>PsCqlEfV=o}Y?KHU3^mw!)J%~L7M!um62-8Q50)oHl`5^F=SS3n2 zsQutD+2UAvX}1sB(_L#fU4kdf0%3-;s#Sz~bP`OI z-KgfWEt~(Pz!Wp0L2ps;P()QrX>54>M)SrklObiu~_L?BZ1$wa&)`=u6uJB>Y zB%fQ}S*s)^A9BM?w4AlJVx7NuT3@q1Jb28u#WpPynw_9gK!Ucq!QE&Uv}_7JMG&1j z4bU24?I-w&ARl(#2>8dCwWgGlyvnE6TjERom;Edevoo zZZ#_U%wWnAb5JSwe#bwd;s?J#8~>d3(%M3`+3UUvA4n-e>myJ@rVQ;t$=)x7V=*cB{S7JSbd+hl@_Mt_A6VgTvX`CYr48*bZQ0_uXV;f3sC+1f*L)a@YA zYeI5QPt|{Sm4rU5aEQAUxMCHpW@9du2tq5lG@Q{f+C`B{@|)v~smEXZD=h1|ElO4Hj*n%|oN8*IanH$p$fP;!`R%1^LVRwISTIuCjS$@Y-;Kh9~ zg>Htp%+ja3+MP1kHWhuZnL%q?7y z?8vZM`Ob8WjmM@U!0kF&QHbA3LU~B(X_0>Z&Eht5HXVEMi*Wm3=cKQZ(r-2Y+dGUh zeUqzyw@b%Wa#JmF!&~&P1#Re!d?IVIMgW+F5xiQkzIW863N@ 
zf2Rc71+@XCHs>y$h3ly})qgBQkZZm3s}NZHj?An(2xBh48*y9*A%$)32`4hSA&@Ra z^aVI^ROS>U6SUX5gVoNR%YXqjGUl`vDRQ-j?w9Um43ZJ*(cVz|4#eL3HwRG-AOU zC^je2H!y~dC10?^v*&Ki`;s))uAJcp51S?zYFBc8yrGYW6=l)7q|~5O`Obl{Ds@4R zHJJ*n2JNb@?%DTK8;^Z{0|i3MzR+8J_`n0%~pL@Uxlm-1UZm z*+=cY#P$n1?IlG3<+JN7vHauZV{HBP7uBg4e;k7TT~1JyQhV%gP0CWybbJ_%^-cV(d#{CP8>&R0THQYZ#3)C_r~Qx>BsB_Iae=j z+ArSP1~v!bLOVg~%A|6jabLZwV`usdE)m(?X=CfF-`r~rGnJrHW+<}JV)f@{2HlaSYhLW4~%o^W($mhbbBRj)l%19`i89OgG)B1_ru2C z*(~PB<#yyX-E-gZ4z>^n%ofQ+_j9Pu`uCXYpXwQ3@7}U}g*9E!gQ0kO=ljI03c7Rm zP0je2(OA=Jqg{SG%p?xati{ediKq~oTsLos-R)Wc#6Ko0-iTV4a+W{x7sO_BHj&N= zfq<(cDvY|%r{~qTwHvNF+-UtnYFbDB8ao1roFbxdb~&KOZpaWpm-GMAN9=&65iPu( zs1sVU7qoBxDz*7=85^hj>(d7@_%mZ>2ZdQ|!3%+;tPj#{!B9_c+*l6Y`cbD}LzFA4 zNkpKx8Q9DiWF`9k&^Nu0r|kY67`0Sp5NvD9;{~jAu#H?QJI=O%c^9w)?f}uT&;F!(?!1H@RuwB0E@nTmC2T+c9CTJ(I|_ki zWJq|fU+_=foyRBUZnuP)eI1n&%ccfCW;pT?l*c zlS)Ab7(QY~=zb1uPs6A$yY(1Pe~*-VhqTDIK9}KnZt>yKC|s0(_*!>jOWOHhf$)fn z$nimQMU)ovqjIZ^AD+tuE*;r$Wxb3*UTUkqpJo4MXjbUVH|b~^tL2bvp@ zjcb+coUVd{8sT}dI|JsU?xJrRSf?TF>@l^q{YCLCEb0m2HEm7@S_39p?$f2t1(#nE zjrt3Idb4!C>KL{+|O5A63Q-^ z-YQ1nL|Xw?yxZ`F7ppRBon}g`F2ZP#G1uq4McaW?``c~vgHVYtjx{a&yf;6_&)$Xk zRVOg4^H~UmyeFTLJ|aL{-&I5HYgwvxrk86=uYddiD=O`c@AB_2jgDE08>@$+WX~Wu zseD5s3!&7*v`}8@h%EYc)UyzVCCJ%t@-i2S-fliNzQ%`P)gCpd64L9p3fbIn}l|l@+cc!;oGyL( zoCP$j(w%or0AxdHg%i4eKu9R8t;(%fsMaIa8`(O$b_}d_(Wlj2u=WpQ{>;jdG z-F|@;l!4*bQ5^XF=*;D3*)LDAqMsQqjTY}#n5mT7S_R%pYp>tq~HWa7{p!c52>;>ylX^=Vba;?~=|3^`1v^YGIJ~gj zHg|+?!GY#)=3Q4h|7aXAt_xeZ{!lSnx8O~je=#^}gt~QXBK{Q(;}{+|*N_rdeMN@S zP(JNGKxx{y(uaAt7TFILyQ;V!cNF8 z9qK9f$J9W3RzmVyQ)OPz8&cXzjU!y_P=pJt`fRbO-X<`Zsxc37NI$W3^vBmUgke#P z>`DFGev>;w?erp)HD3a)!L*3YlWDq6$R5-@qo^`!{t}vp)C}GGAU>3d$Bz+Y(HNZUs?SpL;d)uUBpByTw}is3RB$#@0VK z1>*s&uKcSR_mVAH_d7SPt4|<3u2)0$2@1pt0kk#%hr3~be$zqywh?l9RRi#FZ0YJ2 zmlEBtOy)kwE9W_MyUIt|$?-EZP&bo3yEl#_#Xpw98!m}$Uuh-?%IKgxkZ~%Tzwd5h z1uJ+=(AH+*SE@SP;6UN6!p+dY#D#IhDNK?)+p=#7F%q1{YUGkZ{T|?UfU&HagWIS` ze`L0_VusS3{l)TyUnwed^M=YEtA8no0-mknBl*3B<7dx zL85D8X$lE0P^3ZSybMNyh;A;QGl5netiJ5@btU|*B_Q*mGxfN7InJW~n>!GPMe>ni>`0jK8Lsn{6LY__v9rKTE)kKV} zG+(H67l-`#?IHiA1VHnm+Jp4v3Dt6VMkwx^TA*=KZr+Z)T)69`II(%E;%ci}o3ao* zpKTXf%1|IiPuUODvM>1VeH8qW(hKMIQ*^TxeP)ciAJTnvRdizof2$y1ga9DC#|D-n zuO>K;5^4zq^Conj>lEB>K z-{chi>#;>#)b?GsD0(iu*fLa#+P^!BPIkh2hnW@qcFde@kYGR7qixqBl7LHEwlqy6 zN^n11Ew&*D!E?juB2;RuBlw8f_(@%T#iMst8V#bdnnD6;BW1OxI=01!VV(DFP<6k- z`=+5UVq{bLK;&a(07<@x;p$-q0WJu@dyQ>or`pwZz|-4U19Y288BEQwGgp-PZ_a~) zW)u?XEkNubLURuaMgx<^J6GZ5n{d&e7P7#kYn&(+Fv@vq?&3c@xPyfyNqNbToS5$A zpR4nxX!*0z!~rf*@e4|tGumzL$l@W0A&7FI=^4X&^uB zbj5vx8V>T3hPmS-@ktsp3glESoMV*)-)iIk#=Q|xBXeZ;)}4reLjfZ*WPcfTEY+rW zRI#7eK`2*9%8o?Cfm|l)8xMTp13Z>D)w@4r32`EY`lNUv{pOlTTuzAwZSQ9D9A@jV z9nJorm(MX&Gyq&JTK9Q@k-TOb00S_GN3&o{)%*~K$c~qG3->B$BMhEynWN4frFYl< zTivuN(ic2fQ(Fv+@4L)7)i>|3dM%#d^2%qtNYn?YLV~BwpXlB7U??#j7d@ky6sMoS zSbwrcW2#mjJr|52>u|E6mL;wiYO3O`&gjKq0m^ijwu<`h>q*L~0FnB0DW`MJzpx zDe340ekEg5B9Y^sVzZ_JIpj4yMoHLTOAG)$w!ARQYEOeyW@bsyY{(v726-Tiz$0^F%s|yAHbS4(i zA^3tI(Tqpt(4Joy&=Mk04v+FtA&Ly-FK{qyRO*i`(9T(s|L9JWK*6cMLG5NV9w^iW zp)FSVN3;vjA0HzQi_+V)MB_K!jgUMo|DY!jfMbtV{;vrKjaCIY}r*(CwSmxZOTHrg@={0H23eb}v}8TcDU+)oMwq ze}lHMX~-yyA%_~{3YlPkZjVb)H(M`o^Eol6vUoEbSr(^yj~WIbMGu4WIXC9^LxPFf zP`?QaVV7RZ5yeKfvh!}gW6m^C>2cfc``5KDA%kmEQ6$1I=yU>{Zb*KeGBALT;x_n} z#>zZ*V`tx8Xm7Z>@DURo=jQbB?nW34xc*2y+a4GV%EtMIA1(4qY!=tEfE8Oo*mQfi z(I6i7j{MzE(PB#dGXFD&Oj!QKfM~58oWBzpQH^|t{e6^D<1}Q|!u@klP=Tarr&yX$ z`dvv^6E}0+U*7_c`i;BEiS>6B{FkNd?=%IPHYDv8LJ|m~Jo0}Z0HF=4apyBBMAp2y zCseFI>3OvQL!m}}XqN#5l(M*n>U_g9THRycXv|frMY!<3mP96{pQZHDIc?kBxSupsRXqA_^&kOnV+&>ZQ 
zf~g&L;!YJ?v=}@soHcxD_$Sld&V{nJ1sbZg$WV&yJhd1nB(8EPi|t=1f|aoOX~>_}NqNa>f~<8p4fcus!T6#dSdq5qX?? zjJeBixV)_26*>mHN#S29jta$6cL|bADyJIshF9J9^ck(}1(sh2^rdc_>s$+1?6J8r z9*l;zKc&X~Il=gRWHCh|@gz^3rRyVda}ui4ZE6r3bZCB2WE-w?!j1LQY5kVYKJo8a znxkl19{+&5F%FiR7fm#&lst`CO}Bt-nL?9ZIL!%koRoi5^ii?T9xIk?HY_$YuB}JU zPwjSdqJm*;dA6xEgO>ZbBahpR9L*7Q`n%5eX1Ip zJ%%8iH8Nwz2K=F$B07DOJt#*yD$d|xebt#1YL7d|!b7W}nwn4c4Yfh%^pU!+{If)- z-aOKK1Cf!2kU$7#hox${p{;B@#YIW4F%(wmQ4Yv15;&sr*SNqA*I;l5Z-zPxiCi>o zkx=+ji3f4^&9AD#XnRv~=v&YE=Br=mp<5rI(x-RBO*UoM)Y$ia+Txs;;AX^vmb$*k z)nDQht&R?&4oD&t-lfdFx83|W6~5=DU@TTa*jEcpx=;CWXbtmoINDIWunO@1 zHm2K^vM>9P*KCF?X78$@6v~-WTvpWEW5*zym0T`yK}Ivk7LyeYjs}&&&Jf%y@2S!G z;v|SmI-2aKAY?WX)N1|9#ZN)k>Jq$8mQnAlf=gHu2nlr2geh{&vdX;C0kJz3iUTqU zdkOr)ZNInXVWvS8d=O0phPmYQYZW}YXXNeoI+3>yenC;#^!TJZtxh>@OaU&-d0az? zmimS#pW#FDu0W5}@r_15?YE|4x12uFVNIY*RHSJkhJ0|Y;*>sGXV;>9d6W3~b(T&A zws*rgE9H%YVQzu3q-ByaMC{y}<~(*Bs7giDqS7>%l=#Cim`N*TS52gLGs*zM%;SvBhh zuHRe#ucgJNlS)4uzzAqRubRBi7+G-=sYw4*OyIztjb@k|IrMXx5F!)#WX}!~bfXHe zjSJgwG_3WGi2d6O|KKw$ITa0SBN|w-O_z6&&NRurKbb z&9UwGkcW}@v=}7~UzUt|c_eLH^U%~1eD(@p8uQNIk7dmqt{{7cz_{=nnNR(eGsOzlF^!bm8}|@{k=uvTTjX7kjmZ}bdq4&yXxQM`ta*=LIQQ;g{l@# zj}vJ$uE=eXUpRVaOb#a3`t8iizy|%-j>S^oSVEI&;b&8`w^K=-ZaCAX`LfsES!>o{ zdLL|>{JC7ma)Fd4sEr4l-8l}j0WuKQuGO<5r>w7-Y_R#<7)gnE|EoH$C8qFjiT+ve zQmhaJqJ0)rN=mF^8$6jrh3Il5sqFaw;^{1-+G?Y1-Qr%{gS(aB?h?E>#UW7K-HUs$ z;_hz69fC`7D^Ms_in~iszH{$5e>0LFN%p(k9DvG@vFay$*bp13x2b3=I@5ow9eqW2p0}45Dbl`nN2-tqh!u~U1~W< zT&iP_^!n7`CjMnwMSd@13B~ADWM026Ml#R47goCUh#*QAFVp&O>I^w`*13KVTp=(o zIN|d(`ElEtJ!boeFnLfTsiy@eR~JZe2zZpOuBxLd{qs4Fq62MPvk6FjUEr**xZ+6F zbEQ3%-wQ{L#|31*1Xy^yq+R{x#^*WtFJ7`F%6z5zn4>5Is6Uj#SiEU0ia+9Z@an>c zw{+5|ysy<&S1-eXcj#s8O(4(G(!zs3?0a_DbL^!FV;RU2ay;aO3s%2f{p5u=+`OBb zf&E50Y*0tb*MU0_{J!D4=b%I+DvHSqmxlGeWA+UeqQO0QbE?#tr}$J$)NGqF0ad}h zJG~h$WV)H%_k>SFgN#$aJWfD39kSlk;2N?FxTneTfo4aA3SIPp4SXX465Z$gUqZ`MHNsP$BJWZeZlsX+L-yIM9h&rz7 zvFQrF=@cbCP_w#y>Ys4HrzuiH2kUnRB7gRHZggG)0UsUC1kl75@?^7Ae;)4L&*ahd0BN`)|5^ z@?VB^;sCgaep!0hKKodBxF)klWKR&9p7k-uK;+%f3X$30obRxD?SNtjVXHcQyJJKy z#y1C`2@c8j8I)B4q<&^ow4(3gDHB6w24l8Dp{Z`8gNq|%gpkl}t~+(Q1-h~qfJ{Iy zq@-w+b87Ol@^6FZAA_skVx_dI0_rw?5H*VNQ3gWpI2$j+Q+J6URMz8`;=R;qCaTNS z)*b<_*;|7fpOv=H@Y!-+-n+GeJ=jt49b=a2U4Y{#7oG%og7AaajTIL)*mPt0u?&67 zY#CTqM2D*aZjql5*nWm%W?OrR?z21OtL*$NpMLFy*7QitRW(zg+fd6E`1yP& zZN#oP<@Y%$$CKIpc4LN-Vu(oP#-En_{v%8UXR#| zgFS~AV)Q>85!s>j&SDD?e9yrv*gD)SMT*-GuJ3A%1W;8uHh1(`Q|3-V6@5jAUVE2L zFn+ykO%}I;km7hA02MdlgzBfii&t%c-j=8Jo zlhNiy|A%M9=v@$!53KvRjhY*1kiI^)?n!z*5+B@C>~5H{zVQoW(>bSpV9jMZJ|v+x z$6@Qv&@t&OW4{_&a%y6`wp1^Zo%HRQsR;Io6!f3pv{YZsb~}lis?KbV+6cch&d^c| z!b1otcDhz!7eQQ=c(Hq=g;TV{9WzJj!1AtMm)LuV`$oP8=SC!fna_gm*n*O#^t#cW zlW0h(Qt`Gq!;(1q8eqVl4oWBzDzN6;9lN#PLOB|8=(fG{n+3w$?3o!IwNy{raf+@! z)4X4jm?NRx#TodZB_M>iY|z)J*Rq9f z@5%O@ghDKpbO#0ph`xNY;-iy_AHK<%4vai9GPmWWfhz$EusOh+o{At25{bv}!LIJ+ z(AL3!aEQWv(yFh<%GvGs(^|$S-JI65-f>ZWMB}84X=bN;lMyW`0WgS0b^`cU(NFIo zDcqjbr|E(Azk_2%yFJbFat@56y971Hx)`vnp@-gx_sl59mxb8pQO(vxRcC;yOfe&6 z?`(aD5AEeSKfrc{vYYEieIBQSG0QRcR4)O z;A%ChV(q{!bUncU4CTYWG81Iy_(MlJ`u! zb%}3CzTM&J^@>oOW#tBM6jwFW=v%#1NZzbB&O-u=wy-9Lp>Pg;DwY>sdTs9SH-1?o zHhi2q4c%@tC0lOYoMR5%qAot+oyq``QRG?dej+IW4;Nvl2FS5>A9gdwPpX1{B_qhl zIYL+;Mu3bCF?lMk&U~kKW7*? 
zw=8%rm7Y?dwFZWfsJoff+qAKtikeBuhYj~7Sm5rA<-4t04mmLOS6o(O3pF7glYNXH zCpx`-$$jPwkVB&xt&@CM(@k3P?Qvsk+s9aDSLw%|7)2CLWO{C$Xf!WlB7;fKlJu@i znEP{az0pVwHvejPv&;pL!4y^a?Z=qo#gA4N{N~OX&16SZ_q-X&p6@=Ja~bC`EM24P zaRu-hf^Mw#;SMU*sLIZST$fJBkM8;UqRc2l&DftrjzR-W9Xp7*t3T3KPkLSPlqmzlyXLo;yRI=s#qeUK#4&NV@&T00TbdnyyBEQ%z&6v- z6%e~Wsq{;P?FsqE3`hWZ=9M2TW^+R>gf?ox*4yL~3oS_&FOYHEy&=lVsrJQO7)=s3w876`f^|>aR%NhNF3R7Zytp>C9V78yIDDW%`T>7b1($<2QtXj1rgcsTvjav z?K$}nY04=F=hC10A0Qt$+?wB0;WiU^-h^)zwBHiz{4hZKhzrvAXXXfhj_5>N82u8= zGP8y^J`QLt_K?9d*PeXBv}lF7$m-tDXN0}jrsPW?EdjS@+9OIr*rN&nIO7pl8PON3 zgG|2|CY^pfI7^!g!sU_=mDHEMf)HjlAgg!cMS5|ly;5^v6PRI!KpGw{oOp9eNyR#R0!$GaH)GBix+f2zdip zNR zg`WRJmg^=X>T8LnD#(35N@~3@&O^u@ns-kX*z(#i&MaaVcII(X^^5}*%1w&iw%ztD zb)hRCG0K)VIo(hqwqR9`l05U-?CQAInO9)!)YS^7kSoo}>8qIyO!7Hi>KVG=eBSFy z?2|Vse6!gt&j8XGuF*rO$sx4**uU<=SGCoA^Uey4QXIV=0gUd;ef zqu+C6zWIkC&R_ct zEQ*j9PI}Qc#-Hp^^a@75i+g~4&c*^Sz1*TC#0)oBs?nwnsbtv+X4lj999qVaw2*G6 zXaE^JC}Y|b`8cDixQ{mniZWSogv1Ec=jxAFK7VJ=67uS#`seBLBQ0CaFC2e_s4CJo)LgIH( zI?mP;^4xSR;8p9oIAIp+iT?J&6(jk{w(C>e^zy14l&JFhm-VLel6idM#s{4|JWl^% zT>Wj3GQS zs@Sr!zUAIH{AX%&6XQ=(cYD&u`DxR)2POgQ7l*V~JRrw;RiKYuS8z>Q`mH_H^LPTx znYWUp)j$yXD;Fk<-NQ!RE#Zx%++ATu765wz1z~12;_r`6zS(rp^m1EpHv^BkwzP>u zfr=2{Cz?M5n~Y@E*Qhl{#GwP~b5VxP=;NJy{QKKwzliz20B9H(%w^`~vpS^W<9|Nj zq)qF`H(V+Z_uN?$2+IMo7G|9M)r>H-o~;v3%QC8K0}O-eS}!rfb-vCsS31Ikj%=?7 z;7{A* z&(%Pc+Vd5u@l=GQ!ZR2hJ2epJ@zkOn7V{b=%Mcp*@JGsf8+iqXSx=7|K$f(1K}|TF z{c#AxU9N4R!F`IkWL$UOi_WT~lQ>F<*vHtqK-*-09q4_*M4edleMj{E^U!jN^i|m^ z1+VW9^B;8a;%wxhK%PHQ%}c<_?Ca{^`v8V%krsC}$&0E4_p1ih3dc#<9zG|w)WN+6 z1odTRBZGWq;#7atiTYOIv5H@jpYg0ssrJ(selKNcDn5E3`=}r{xE^xZc#*7$0nJ6r z_KoKRc9|2-3Jc8Ilre6?G{(2XYx(pI|Yus7U6h6LM^>x%7Jrvt#s^%dh4DO;ZnJhOzXxVc{_U+y)Dfe?Q zC?6ZQ&3>c+xUqq9EBt+0-rx8%_L{wqxS7OHE$7ssNUl^T69aQ%Bn7~^g_u%V%Ov|faWnq?br4l>tc1Lfmj zKz+Z-VQ&7#)SH5kAWz-FZk#!*=}%5y|kpEC)}Rk7GvVYrjS&G>v&oT-0o zHNOSjI!=*`6ZY|W;jK^W3x`9;tq5O`ODxXh}8Ki`#5A5 zS6iH~DG}wZf-)9VFs= z1{|~zklNyCWOXd*-`GM#Fw|wgmpL;-895>A;I}?r44BPa(CVg?(NuCIF;W6_fYvV4RpmA_v31(S;4(xoSM*A|80z< ze{$E#%^F`8$TCFi;p~4HhT#pbIS}Ur$kTdzqm4}04wZ~6T%hM7>7Jvk3c(cLJRz*X z54}om${J|l`Xx5!jlUo#Fc1bwFa5P_DBsWSOhdysWD0%3m5DbF|MaiX88@IfJp(p& zQXbriZxcFI3h_j%`G*}_rZQ6hqvXkMpBPR)-{t}{0q2(A`vJE-N%U+7otyP&(3*T3 zZM}h5dFJrMt74PBOoy7LUjY}^%8mg<|78-j%I`qQNQ?w3EdX8j?7jUyMAm}8HZ)C5 z)o$~-kS#BMGN`$>zpsYnrX2W_j34=mp!S@fn$RKf>mcti<>3Z}XxJr)4VGtqILk*LUT zTFtdI1)+U8$XSePfHh=Llltd>UT+`C#L|)f?z86KA@w-Uut2wDu%w<1M^NGc5@5)m z1D&r8!e=|eqxAKH$51tYZ;MNvpvdh4aNz{if{>zBY<(18pIM3bPisi7A5hdw<*9Dp z(??}>kroDmw@U!$de4K)v;?vGmiu!|8**=VyD*iEb@r!@v&A*UPm4ZHJ1G3+$CHa} zpNjC)qxI8i134Sp>|xvC-^^BBpqM?reXWQLNwYfugE1d=wp-B5cFUo{79xfL5d+5G zBz3p_OJBH}NE0pv%Xy*9KSakzelQ1T`(f7r;7#|*A+sLHiAz--o4%Nnf3QO1rQg+% z!<+8%65_ei1shI}D1j6}5xJiByP`O>$)I++8F5m6inQq9qam1FPKC*3PPZ3q8(o(2 zXqa3^o7Z;uVKBDtOF(9hj=iS{bt78E93R>UnFc0_@jTul^3_0s>my_T9emc8c6P?>_lgHphTO~HCyXGA1 ze}n!AWsu-`Y}a9=M7(OGf$dv6iu?1EKTiV({?^fc1&_y*ttO}v9fsZ85xSiu`H4wP zPRQojO}y7jsyTm08_R$Qvxm zV`))i804p}e9ThxhgIL&8puxSWv^zZ%A-beVaV5(wKNR?Rc^1*S?>eASsu(YeL zhp(Tx64ToCdeDU&R}ovla+af5e#)_@68lNh$yAEt<89pT6v4aX(9} z`o({(KF38+n;dK(ULxl6;ph&P>mQG1&{43qMDKE`)idzWzA2qJI6)j&$hwTDtf3{v zsRl(Nq`3PV6SFB;$nLzqw_vmH2=@o{>8mbK*JOzU67)cKz2FCysp=O|G>cbwR6Mur zd-J*O@CGiIWKU)nv`EbdwgmQl>XdGy_k zx>h&N0JZx`B05rZaVDsuetDBLxV0LjG82q`szEIi66}1D{Vw5>jMF<12P&TsTD=8Y zUAR2rq42_x?AKUR@B45rR}}p8p&`*zyu~@V@)|$vBZFCR;8sf4zFo=w!lkK2n{WT3 zJwnKVSg>AWb;J#(!rv5=c9PIR+F>{reRBNf2gD)11=yDMT7xmWzVLnYFVOL5dkr-` zA%%dp<1tHO-Zm6GB1Ti|2YhdDC1JQEZ0p2vd*vzNF!#sn``GL;d!fi-Zlu8x$OJh% z14`(>8!2TyEf#<1FRnh6UBVTx+=y67>%ugi302PyfikHA*f#4oCEU3SS!3Eu)tX3A 
zF8W0N@Pj%n%pC*&g>UmBvk}#ftXF+i()Lr(w>%vt&&1S`=*tXEXJu~Y>bM~Bu{Cef zCV9`bTS6@AhU}Y~29tf4`1hyZlXJQNw9*P;)bVCx96ZHuQ+_;#kyDZX!iZ($rX%w| zwxkAU8Bv{@5u~v|&PxFkGQOwzg%wvMapXWb6m=Btd|8xSow!SMzlT59=$v1ckU=)x zb>}93S!;EZl9Ex88Ex1cc7bTpJ8Fs{Nv#HB_9DTLntvDJgY<-+2C1z%zTnTP#va6!Hng?ZJO|JUsU_~(RmJ61E1o28EGw(|uEVqX*iG4)IG=cuTh z4@CK9&dy0M$$fVj35_m(!Y|1hFx{LpknA%V?mcffC65_=-B5}G=t6VUgES%ltLGW9(ucc7+8+I8c zS7CFAtV^+ituHm zKKx%8T%3x4DL!Qu2ajrWbmd^ir5!d7FZ`yNy7Mg=qF=h{Had4(^QcpfE?EChIjH;q zw$Kc>xl8ejB)z{ikt0xEGNhokba6$o`_G z9sZP1VBwp^snM6$%BI286)N3}St`W8`;537ORGInP94zAQP~ee~ zyh%wGg)pa2#w6LKH>bC+85t)c&;bHF*sV}UtvbtS^4VwxU+am`!YzC9ykUIJ-NIKM zFAvw^Ary~&$k1Dz#ZtuzY`D&jZmc};+n+2G3|RYHYzO>w|DT%HCV)S*Lu3NLe6be> z|L2P>CH~ty2LIUO7xI4sT55U3j?~)Kz=SigC5aStdTTf4&=VKoz%!r(jmN84hFR7i za9GH9mcKu!XNT}G#`m-G{2hs#_(w4g%-=NUU>WV^RF3?B)tp}z>S-5<-4|z9 zs+MystF_f5BH+SximB*2=A5eU33Dos3$+h;|F!h57Ll3Oi6qd6HmGfCxr?c0Ni>p+ z5b?I2rs);uZc8XSNicd`x-O0)5t!>kwA(CB9#nVO%BSNaph z>r53vMkcV9P5YSez8g!@QF|xMX?+QA(aO@NwAKEzl$gn``1__+n<0!;VCm4wuCp~U zzV5C-jl@!0^U24n5@O{f63z=OMR9K!dDjR{iCdvs6+j*wiC!aE3xPqt>_@1r?@~b4 zy)2lW?h8d%eTSsWwagI-*&mSfJCxR=(Spt2)tuTKeBjT>(1U)lU}sP@LU2;cdgm6h z8^*4u>m&PZaN4Sinn6nguL*lk8QxpXhGr__h!fNBMaEBPY>EORP1b0Crrc;*9v;uw zP1qZY;VYMW5c^q3x;cU8E5oQo&SsDM)-KDut?{6qIPQrGPUhBA@-06>JBwS2*<+=(=0@BN%ezpX5@%FZ&`g zAM1buFi{A+c#|f2noyuB3!xgU5d@BO*0h_|Z4{&8vW(LNnw>g1c zVZUa{r&!xJWGQX|u!(Y}IF9Nx&*2GDN>VnqiJy8oIrF$<>G%m)4`{`^QM1AO%A2a9 z62v~iAc#~+p$z?v*hL?lB%(5R9uz||$O~IUq<*3Q(CVBseqQ-C%tpHlXt$&HvbLyQ z6Nb2m^ zBY+fh&wzYIweZP1-ez!OXCW(4JXvC-~se9>=T!;G$ET)QzQM5++=$r*3 z>A4REmgeBg{$O30Jdm0nJ53&~V(v8?O6WVnN60@6LbIMkvN((0q>d75W#4?Gj1-lp zB50i-xudF!-}D02Fc<%Q6OT+9l!MBkT#J}-Oni!JD6aY2r7moGikN*t$2L_=(0O%b zjo}zmjGoROAKpb0U7Pb;tPW0zwNGM`;kEr+TjLzF(5ojO;a=|IVM+;ChIS`GXaFsp zFsL3Q`66zDid{c5UIEpC_()KOoXFI1KX`041!OmO>g5!(D^g0(d4xKl5cdf+pAD=!nx+cEO>e=6eAL%Qjbk8# z*Y>Vt1gpv)z|9z^-WZx@u9f3*LKS^1m6MtMnc>C8d@j0y`|p?@KL)XWB&PO(aDLLr z-*_kR+Ym|7Syp2EuQh3*E3MV&_YDkl2!K(uLz+A$b2Hn79K&-muwstO$xD*!Gd}6DdfwkR-W6igBtYd>MB4>x%!PqsqpjqY7n~s5@MIx*>n~Y3XndWPw zyOdv&MFUVm1mq7%_8{CNB6rJdGWe@sd6FWhREH{EeRB}ByoLhT6J=qGPx&fEUszZd zINdJ|Kn`cc>`EzZE$3zfQzXRJXKx&$wg_j`l(qOgXSA@?bvlF} zKJRVT#pERi1KCWvl~)~_B?c$Fs7L7YWb*2q+O*f`IfUx4C5>SQE{10xB)K{tvP!!C zWbtAQpsI4=ZVB(MOoI2+xNwZ}_eiasEh7CoDxXu>V3gV~is3$z&RBO1S1!c!RC!|a zKeCY&n*$kce1$A8F~N4iRz`AvCW6BIrR)e)>qMClnUds(u-;KZ`X{yoRkS9mqr5Z8vFIB40m1#p;Wd#dsO{t67&$xa}4R z*7609a2hmTBj-|>o$}#=x#`V}R0qv!j6~YNSkOgFjF&Y3f?cA%@a~amuu~XuMQ6{J zn?-V;VE|)5r-{?cesipm>kSe$Q+fDBS+9?tdH_~tm5-vsYoIb2+YiaP`>W4a)LaCY z?$*6uhnAE$qyC87kLTkAn~l?v>L}$GxU8VsN2!FTTWnM8=ne65d7K^_rI_cq(&0!3?R{@bj_R^-1a_We+3 z9=sgw(wp}&S>_clPveHthG53@Xm*>Jj?l?F%B6QXUD6W_Q;*24q1+7uCO^_8?|Z^? 
zvhEv9kG!|}kD39ZW=Fh4Nym2zfwYh3PwyWCi*;*56Br+ZtqbuVQGc1)9>xX3OaB%u zCSZ+CKo4@62|x?zLI+K629HraweC5UfahM!WM6^NuDpRD&B;cT+}GsQaFdVX26a@l z^wkSUt{A6Kq)ra|g2qBm^BM;G0`+z5wqjPYzj$n)$g7 z;D!uV%-o0@#8;z-)k}$^Mxfu2ul(*2kOxBzuLsSkXc1fy6byBug1KNhYQ59jYMAdT zhR;~y%OtM_Kj0i2K9dsh1nUSKKpZt26MqU-ZkpEc~Tk~5oW4dCW7Bz>I;vH zZ^L&{xJ<_-`{1RkhQZS^>|}`7P!QGX*jE?Dy6JRnHbpMQ>TMkY#hPyNLZ1ugs>8Zn z$1y*nA&>YIEUULtk{@*CH~W=UELBg z!NX&Oikk?7%MDj9btQ3$xNXPMo=GU-=8MWMgIeUK!!E(lfmkRicyw@c`aS%QT%dI# zB^&KA%sxqT)wY3mpnlM@0Q?Xk(MQWza~4of-34=DjB1%XC$>3LWb4hkC45~t;SVDQ z-9L^!ChNE%xG=-i;uL}NB?;?JcE-ZFHPnwrWRw7``o}J+UX0z-uE@1AUGA8^G1nw4 zg&+`oh#WcRqqn9E$*>zMAxMsYZuB45(0d&r{Q{B3ZhQdJ&FcK!+t_edB@f=UvxArz zgb(^nGt`tC`Dk4wPnek1o(XXd&}=U{X>Q@0acH z2kgqhVx3d4dE<<2?Jl01_72qns^dOn*Nju4Uhf21E5+vB zFhh5LYbh%{pj`6m;`EufW>^qg{6SgiaH2NS>Emb1p=0P#w$`9;M{2zC^VqmX)2KEX zVNx@|C_t5YMhUACjSkt5T}OgBiZ{Jyo_^^@gShlx~blbT!d-c75BG9 z5kXjQhdQX9}!V`~W{S!+a)6LzuF1BB?x%Xki^# ztU~K|ts7-bgz2P9hCyr*Z-Q6?mON7cI?wi{Pm3CG(3JiGu<2!CY$Ym~7kqrebCVbF zof^KNS2GVd1X?r$hTLr#EBk(w$yf3?`YzC>Z*pw4$69%M)YGDUE38@~E z-GeYy91wHcxI|1Sh=7&n9@P24f&J`#!mQ6uMt)0&RpPSDYpn=orFgkJ7t06p(-_4# zln-I?-1~qf<{-V!k8(p$BRy?Tg38a?oFY>n`rhIZ*`N!Y!skWJG%-3)_vIV&?^Lfv zerEUePY4#9S{;W;T2Nv-$Qfd$%BNbbNMgJ@R->W5wVcXn!L9SPTp%cW!^0M{V$~_5 zJjWg!TPP4p(dqf1Nn3jnRmqLCh8eNm%|BcMp{awL*s2C`{0t>o_%rZhPPLgRRJq&i zPkwOpp_O`rtdN&M<_FQj+AW@rR}u|8ShQn1`aJ{RL(4VN)Uh>gaqL>meP17_?(-<&I4kLs_XCf=-rqh@YsspK29IRBY&1&>pRBsO9 zmAikF=Y^Nv=F`g0y;&EgULiMn57-L4jga?!&x{(|ucryu5oX<)m=m}oo2IG!jX0Il zi{C4`hks_zO1=Tz7B3`r1B=V~@6)Sn&N?Z;Ry7dDoG*}#sm}SN#iu9fcn|}}XTWPW zmd;nC>@q?cJ|2;kKPTHFLBF}Wli;LYx&i|yof>D=^puKVUmh@)p;~^A7j!Gwk-n= z>e`)m%cJ5Emv}_|(*YR|XGJDFjPIY=5R~hOce=e2wfrbZuKXDfeEuU|;fI@GHV@Fk z({77fGMd{=x(z>T`rXm*m^WenHVcxlBtNw+ zdTuGRl2jl;sH=>SCvSTStbx?ZS^ULsU$CG`r6hgM6duSV#8$Q&ncoQo*|%>KutFrERV@hO7B0H77+z6izMvGG%(RKQg<=zBY*7@I* ze4nF=hVS{@!a~=TF^nFFm(yS_<3>*eHul3Tc4avyd2_&iwyM6$xw${p>9_3=$xS6- z&sE%XPaPFKEi%UqCY$JQU@L-M3-}!Nzw2MJMRcs2eGO-nX(<~=JS9J1XbMOjlJqZBK^?nfK zk3HHxI6=qLL>9nO9r`}uR#aPHHEA12v};4I>JWUdp}a6DUliF`M6rh~R0spd)iNEs z`h%HxAZMx=GArZtDQW8AoIy_NGeM^f2zV)XG}7b&>}oLONy3)HLxQ!8Ma(&d>$&S? z8gOgr(Iw-I*O8q!_z9*#G6hY)%5$aZBz;h(P)lbWW2y;GhXH3RGfI+2X? 
zy>2$lU}K0m*HjpUuQ?m_Un+8n!K*VY%nt@9qeRM%}L*ageVxJEXCvt^3phi6PWx|c+<88sAZhew3Yr!4P?AT~ zjSgk3;Ar~2KU6V#Wc3>;^J8MO{@Q1Pf|Q|$zU~jnb5P}@1I%o!aJk42b1<0v@S)T0 z{#wfFhFxMOub=$Si;ITnWatQX%d|=K7mhhN!(8lJ=??W-VA$48(A&ky2BPHsOIF>I z7H%RRts+3Z;PmK!$L;qb2hL_(0P@A$SVw+u0|wmSuVOWiewFkjCYBp@%U?e3vw;(9 zx)jl<+8Qp}w}H7!*a-R5v*}Ww3!PEaphNBDjfF2L8rgd@buE<7$p=G##%E9cQ#;~a zgLTP=X_P?&>Ss1dvO3zm1-o1-ZC`=&f8n$tv%f317xznY2A zzp#hwzB7P5ffIt?H)f={)Sy2%;vFSeZ|5YGr}L<~>qmYTj7*e_4Z^>C*5oJ^(MVI; zL7T5`qly~AOrvPPU>=Q#w|%2{ntpAdmNA2vu;6) z=%Z2YFurP!5;LGI8&7h&A9Cv%m5~ody@?1?zHJ&4n^xV9R`?t6D%V_A575Q+miUzS zcBhAFG1O5N{~?IGavhRD=hs6RHRWJ$##+IsJcH4JJGZiUaykDmi+boKQVe~?-<&gX!1rT_&HlldcITJ+5M)U zid$G@CnMHl*@FXWH51BJkb~7s{{|_6`pj#=8hnk>M>GcT62U z{_0;f+9>v4(iT`kacs@#DQj3+3){SWZhm;>9tJCz^&gh$OyEyV_e4cV@pR86l`nlP zZa8%bQyz&&;>q6|2!H1t8$tcLH0D{u))V)_p6fBd$@3axUl-@;V$xWLvi+VeZthz^ zY@@TmMKD(oPcN(xv;tZ_;R^U9;O?qbip@IiX#N&$B)@LDr)vCDoUzn4R zI6P4%zRgr4^$gO}>GnmW9>-!+pD3xANbb*jVMtjyK(D^jCR3_@fiGPW3oi6LBPEL> z+|)7)(B+#Q;f_+BiM)UJJvb)}6TkT=^vt2)1&|>I25jV7b0f!N$_S%l0?RNLqu%{f zB?E1V|IF`|`vT9~9qkh6Sk${9HT;v8NT|=@8j-G~WMF(z9c844QO8WWY4LtmfU(1y zh{mXPxbd=eu&E)#Q|WwIShFit_n#+B7UYfOnAf$+F(W10Ay&MW#4EuV8rx1rME{qo zr6NUN;7{M`ViX{z`2V1L3m%hYiInvxb1K|#?#7IDR6UPX5B$}a71;2l{5pd-Or~z1 zC?^LcR;B+LuGBjD@uQupk)vh}F4q_pk1n+D(m!ISb;V|>fI#^{pnW{j1gJ9t{kT(9EMLyezaMC$(Ag9 zl_C%qwuU}og?H|)9Yr#Kq)JtI6*cV?RSc>S3~l~mW3Kz?hpBRqXn4tNK8Z!s`-f*< zp`cKu?-6(YNR6u3U3_Qa18pUQfs$t{CWjWYJW1~j_Qns?jSz;g#dGQXgfR9e^8($( zO;>3g-49%1Oy`y4X_T^x3&YdTu?NDDB{9+(9mUc8Us%f5wv45!#^hNBN0V683I%L)V_zZb9nuP+l=pv>Q~`*8 zZI6+TsxR^kIr>oBEfj$;hjb$E&nuX%JwIO}e=FE?M~ur_*>hMks%6|fU2xM`oMrcn zQ?JptVCQkY{g1>VujC2 z`vf>y!0cMR#!yK<2qPk;Nc1>*+u1oc49Ct`-jG&B%=+9*HY!s_Sjt12zrIS9y5*AQ z`?@tF;oAE+9Z+Y#&68W!y{69{gb()cu!T?sUzwDvl<;XZo9&QQ^!%eQSpq%E>@)R7 z_9|4(WcgPW4B)9ud*PT6H*0J%Xag3HQz|y-b@F9cQh8g%6uzbJg0Q!|#m!ku#X7FZ zswI7@AJeqgzLzmO5zs|+Rsg?$JiOg%)V9)^pZCF1-*UgcWx_q;VA-|J4?dF&(7!eN zEbjjfReRoTwoWg`E3I7mPAhbg-?NN%$s-8~ljEM)N|J2Ov1L)C%^w6jU(e>uZ?)5uKIeRwE$IhoZt@QED z^j4zkO0q(F9I5!&6M>~ zfmsX>%pYXd_1)@XTOkuoowX?vTUjT9b3`T*uJi$9Kf2}=MVDR*`;2g2J&IaS@%NLT zupC4gIAUI6SW|ZrbCar`IG)bG>G+YOmRgVjWHm$u-MXIRG3FvcTMB(fB3~b3q(L>J zp*+;>D`h^Ig4tUm-^LD|@5$?~0(=L#rkIw4KGX(eD(=fVn!rwp`RUKRQzaFBtvYVf znP{h4lc0oX*L;o~$pe{J2>r9djqtbD8(4JajGCj8d^->_KTY_44wh9z=y1>tkI&L7 zwBx3f>Sn^reQmtTbl+^6Ii-8OBZ#{sgC~*=yo0M(zd(TucPqNWjbFdQY1k&v_U<%H z`>B5xh!_&8gt%+hmA;)fX#JbOLPfSpeuv6z?SEe_R#TyZPoubfo3{<7`Z z$Y+YRIzSwEvAJ3htI0sK@;=Tul{XGD*%EgJBXud_^<(RsH|tki72J7*mY-Ow{@Sy4 zo%p#=GDjosJ+8TT(r#J%XVk5n8_&i$ROAYW^W z|8e5lxsOPE_afXC#vSuq&04fd!GwK-Cl!WAvZ@3cYA-A3AEZS37c^Q!ui70ml?3!;OWE2yLkggn!Y<97&tTV z)NQBJ+lj>_buBP+ZKdu?RcA)_E$7sCYT+{b8boNBe)atAFWI+C_l~fe5L@r);(sru zr{CgA=_)jbj=LipPZZ3M)hTPuk!$sGe~JrR|Dt$xfe?c=vOQ2-ppGTafMk1sY@UXg zqEDiZ7cQ!{Yu_{uomRe_ZHmRYd#-NoVKh)aQboqc14H4~DoLvqQ~Ix-)Ue;p5Q;!Y zf~2VyugT)j#pk=XtTL6t%Qi>KdK)sk(%xFH z!&`~h%jp|d%hUIB?ntN>+18tkc?^MFo%JJIDHVsboo!2{2j!?m%`g@}Q#r8WhP-iq z&#EXeen_?Far2i8VrG1Kk_4D8+(8xk;sx}0m=aq{qz!BL^4*rT?2W$fDj^Y`m~dNC zpX-||;=JuXWVg5n_3D48l-<&-UcUzgAV;UjZc%TVxcv`pZvho$*S!y`2q;p5bc2L| zAl)s3bR#XHG$@@jB1lQMGy>8fEjdbeH^M004MR-)Z}cgj_xHT(dDr^BwbTXU%suDq zeVuE^IcIOj@e75;aIN9VAp!UuuBBVLmErl+g(uL5d7Lum9rZ9;sV|$sSmb!RWw%*9 zX9%;u5r4#WW_7snCB9~Nq8DnMbuZ(68Bm_$#!1t&YD7kk!>UQgSS)- zgtCa)D0b)-Pg%r?p2fl67VwrU+sleo4`aH*lSW7%ZDeb2G_4JGJLJJ*!Gr0vv-_0X zv-52^d-q-(ZWuJzpR#xM5!kki+bIF}*W%>Tpi;iZaszsI zLth%=A;Q10RwPH5`MPkYTEQL`-SEIJXs^Ft+mOyixQJC^q5&=~mV@N5QyuV(0==2T z)yaeKO?^O5VuHMyf?R@6~!6;O$Up~J*1bl{2Qle zUC=2mSmr+JhPUKq;##Sw%bPI=%itRM1QLA-3thxZ347XP8XaZbsV2aU^zngBV<>D* 
zU?ijC4Vtn-^O^5bSm}R9@5XpvURdL!lde`15}wyKrPZVy*UwWJ(H=jP3UWXg6q;qW zY*2|s2<6tzbRMnP9BvqE0>vFK43vPT(~q$qp1=86TY5{(zqHwp7tflTOHF6-Wjv_a z@=sSVN5E+8rz<#9oqKpLTOO<2I^q6sHGcX9={O^oG`5A)^^fA2k?nM}TDsU;)4^3M z6hF&6s&(#VVWk_uFh-KQ%L~IYN(BnQk3Ee7LAl#Pp!)PLW12o2RVwrxdn!*y(no0;?~L-UoZ7aLzq_PvuW;P^b9{jdDV>=_hig($AsHwj*6AdrwE5GwO- zG;tsFDwu_7btI0q)~?mCySSGSvVzfpi5PDynrWwqmR_5KMVoAXN6r~($^j{PBG-mI z?-M=<3+I82$z|yeEs2TEnpX=%+SNta4rbAss8YdtRb)!D3191M$(S|i8w`yf_O5@Q!wF%7L@v~06)r2)>()YcYmC+d9d`1< zN5`Ss>dLB&>C`rm^<{*MdK1v~I&pPa+oE>2{cFLkZx$Ovc}^l-cdyB0LuI71@3uv! zXGac}c|JOZt@-Nvo`q9~Uc^l8(k5_BUzrr2+{F*z@RllsE?GTTI>Yv)8E?$4&(^^> zkD*3@-Cugimq|c|Tcezu(5F*OTA%FRcw#W(20Z{WkxZua%1cL^LSFClT#gQvYQ=K0 z9n1)OkJ=^gVMdQC+~OP?(LRfq%SsMJim;+o7G7IE7@IesM9=CQrT%yjP~9utm;Z_G zIK*r4;o#^YPXPnU2yg${v+7xpe47BRy&v4uWoNty+PlsqR{iOI-t80v90*`K{rh~r zbJbt#-P{dMH?&b9Agy*uJ_yld?xzaJ z?3e`V`j^^=k}9yXf~9LCSViCG=r9yo2xibVZmUK#hLN48lHY+adhgRFup2`;MOf_4 z6oajE%9T|IzqVGo%~Buj%{q)`b!S$}vsg<%&Db)VDd@s1m2oQ1QOQ%@k_15pR(JN( zwAWe1?2Wq?VbRpHHfPTH?dv8!%e66}N$0e+`$kQRM1XR?c%ZUc0(rmYJ#L7O>un=+4NNIBFTp^1z4}^x6k-quOxjp0-G6#KBN`>>2jrAShR`LN+~5?E6)= zPN5_nrZe?87~h?K9@Y^J?~~H`*{F({6zK5ObAk8pz@u`&0*GBD#FDS#P&GKiE}S$N zpv}sc+8P>}tQoO?BAzf7&6550;@*K_EOX&t&dFAw6OmquNjvGF+7rEF`#viN9`smw zi(C)lYRjD+ugS?S%T!Id&Uq`u!ahr(&1UWPUHEe0TTAnHgx5Cb?uNIUA5vgV2;}aW zy~NE;5&6_ei6FIIz3}#4aPgO!lN3ZMn`TH^s6P2*RMRq*5*atOTa|y2g3t!_O!f0S z9nYx=--oB~GDl|9KpFsU*u{_)9#f(E(lh?}1~On#_aZ z4Ma#*7c@)2lMbcq{_=v@RcTQ!_XJipENafx@nGJi53dzrWjllnw2Ij40kXBZBn8H%39(0PhB%( zrR3s(jk=EP5qF(m>R8F!HfDZL1==PR-4-cRwBS#ZCBHXEf$JLVhe3Un9dFIt?Z$h8 z^-w>Jlv1{lO}sc`(|Dl6T%KE`;D91E?8_qk$Yvzn^M2N(k(vsAx))e1Ej=}xU21KJ zx=#Up%5o3D)^zlNoU(iQ+Bdj@mv4}#N30jf1ss^XLqbTjkXVJTwdvWoD~#pII@GX4 zS%4?CH%}0+4(Mx7o&nVrYwi8CIt7Yw9V#~Gx%7?l;F&(dO+g7&9fzT(IoOk-eq#sR z)%({fk_YpH{d&9*JAH^1dT##oww_ZaY_j@Kj@xq|uPaUMahPr1J@#}^zx61bt-Ym@ zLEcf@7L$WM`>j^17LGFsGtUKeOC;PX4Z>4B}| zyh%#rRE2%MEI%r2Xh!Jx!$^x(l$G^RXCNq+mQr39!|ztlUs<}FFpH?)^59iM~; z?~gw*REV(joCzzJcNBj6X~FV=*;6G zp@2Ig8cyDv1W|!m&0YlC?^tOhJN2}Z*uwIQth_c zFF-JPb3X7c` z^l>&BNf+O{K6q>U7zb`2S){$Q>xwX&=O|aPa&S+_av)E}^HS zLg@rU-GL?nG;`zFAf4&Z`6cn=@lGOWE^O#F^#LyprB%wpS^z=7x$x(<7BV!%q`AgH zUf26<8m$GCk6Lr@;RwD;XtrkDI(Nt2-1G?gZQEHl>=3@ZX!yldOZ++Og5VxflTC+O7ks?WT9HS_+jBH3E!a&xS%s>QDH-#wRkCdcMh+qa*s} zUaO%gP+XCx!E`VABh`LNDH4FOWwhZatgfQ1kWjufHt+0dj0Q6mQGrL0@z8e$@VA+e zpb*T=vUN-i?>f0>?4*&e3+;g#hAvH+HbL`-_A33o^43EcMY$fb@!?)(nVABmtxL2| zn<24sK>Z=feKU3?bF*K8tAHTG$-ARM>$Ku#Ea86?NSF^j_#o)b)#LyH(yn}3ivmGc zD-q||!Z*}I?sBQohq6*ceSM+F*@To{Ct4#FZ~0c>shpa~9(f@-^nzfYXcnJL924Yk z**eL1W8%d4Kyh+yk)Q26T^m1ED7(Ih)js?%Jnuh}b!i2NlV3!4s+LI5h4^*p z8H2AD99te|3dc=*XaH9f3RpIe8+as0SnZSOf_VoRd>60wy%R)_s2O-r4b){b$G1z% zYhmNP+R9`}Yq8`c=!xrB5?hwU@}SnOLlP#~*DdjCDJXx=BJ7QV9`AKvsHORh4RkCc z<{+IVXsg=Nx|rk;fBW#hY?>J@#5F1I6W+W2j&j@cH=;B+vgy>$;HcFGx~@?#E1{sU znMk1Ag-!MaJGBos8C$xn*WA@K?X?)vt|yBW3e2_VdPJBnD4mKj_-B02@j?nSI`*<| zeWL+3UrMt+zNDqY-^N}-8$@@!Y1w$-E+N2r48^~*F7~-AV-COV)QC}e1@eFuIt{$W zB}P9eBNLqma}4S6rg%-?Bu=Y3icdXOqk&8fq*op68lu$JQC?Lckp`lr8Aor;z7TD< za*vvoa7=dxrN~Nod(nub7M+Iy{edQ@N~eODEukwt#dA;Bxvu(8Qh7Y+=)#jmAoSEc z-DK2qI^pnq9}?gP*rC+6j$8j`Px(`_5Wkv}DTk3`A3BT?A=&&D#&*N|?XhqDuFZdn z@WZ^SIkWVzwG0of0Em+G=lsNvgDA1_oci(J`i0?ZY^c}@7Pwp+kEYl1Hli;<-qqL` zplv!C#pN0*!$o7dI3|06|DFKwOISCe4Z)A2jlr2mSge3#HzuPD-HmU<{4$Z(7ZmyF zyIQDL+{8at?Z19g!rW(v_WeQQ?ah<{25lt8ir*Vy0stMOL}+4>OZB%xKS9`=fIJ6U z&n*IZS6nZ7>?)a!0PTg|Ea7>eX{7wF)Vj+(j6#vz;0%5DMj7m!6YOr+yy)Q<()Tb} zIWG~0YR3=!Cig@fjFU>`+m=wXGBMV1@ug>;Cc&*4!2RpGYD}g3so(Ia-=3Xi(1wIb zd+KAiKInt+xe$A~q#9yZk-V#7kz^EaeE{>X?8-mOcQy^u!M3o-X!no&Vb8ZG$x;}z 
zsv5_v7AeR@R&>O&D-WZ9BCs3x?oN`eiX(!EgSJ|pdP_WUJ9TY*>ack85U49&7~g40 z8fZ$oVj8fLUtt53~K zo*q~;Lxk@Fef8QIN={x`H7vn)ZsG=F)~y7!OFp{cG8WMG@(p2djoW4;N3@nVR(Ew? zlul)t4E7NU_o~16_-Zs7e!RRccH9$pDL}9Cg-w~3r&`fpG!d*ZHiN7S;`n1!GSb!z zk30lJL(DW4Rg z0gZ28i@~+u_A@=n^)!Y7hnu@wf`t_gNoJlDInVxEF zc?aK<2)%o@OBQIF6!dzmqjw5FCFm)qE%t-zR>wwNGikV7mlbb>QR@RGN5Sx;WIUG+ zrx@UjCQZ*L)Ie8fDA`Rir8JhB?vmyIe zsUK82`QI}Y!`(3#Fqd8Cia!M~h21I|7BwhX=~cUDuE?)^4U3KcWcE_yac?ad!Dm3D z^E%8BP5A)ULOQuY_{F4hYMi0!W*Qw8T$ci+B;DGRP=zY zHrP!SNr8Z7$bpvm5#ch0#%g`zF^{AO&>qv%jBuWA*F``nuV8>?x>0#d7=>qAJg8rL z`^gsyWqnVQeh~Gf|8-@6- z6>EFKmy#KPDnr$l^2I`aN9Wld+Oj}BI(JKCk#xuE&AZw(Qa=(w@oKf^#)sj}(}T zHqhHiV`fxGSl2@Rd|-y1Wq50w6!M3(Bxt!hVEVH|tsN-UpWI8(&VkVxAfwMgDxWVt z{j!O{xL^IQ4vn0`cqh=6aYij1<&5vKMCglzfiqBkumVCDHnE(CG$GSJ|4D z)GhzKgJ*zRHQIbX3YPlA$q!A?S;2KgX$FJZ@$Ga)2YPe~v447{7TQRirSO{3zM?6n z1EBi|jt$QNQ1iCQp$eAiQcr-!sH2&ZMit_Exo+Bmvw9=+tsO4M<`Ue;W2=Z9Yk zCi&EK$#M0sFE9hoa)E*SjTORNI_`zAu$FPx4=#L>0~&3uJr9!wWi|nwVqQmg2=Kb~ zmHJdHrnp2P!G_m7L`U7o!*crt$v!-oCql?*Z9Zti!143`54?zMYB zGm!8%7b7Y!|7de5pD)ZRUo2dVkKLx5w)UxXkn654Xp9cyJOk6u1pIywHHYW1+swfl ztAl+<*Yy!&Ni(6YtG@ZEOQ;_=tBJJF_+weaKrVhQjJkB1Ce^K^Yw*8IV)TGC6%fj0VfsUFpJZ-vEvNqdvz~S7-Wze+C z>0e6V0<4#Vrz&h_jGmyD|qgA^$9(swb-UM35ts$?5$mz*Ywlz6ph_wzHC zQfsPk&btMeRLAw74OK6Lw8`rGoe!`NpGrHQlNkm39kBAl1Xto#Y--o~$Diwn{tz1* z37Y3_9mf%N{W->nXji*rw>HgNu5$Ri*q1MRw3G0Uq+K`YfUPjvtc6VJ2l%=#ix-8X z4PCFe+LtlZz)53R#U;ZZVIM*9uAw5`&_=$4cQG4}+9a!kO9TrNjlo$*2R$O%Lt0-( zdrcl1w!i#MbWS}$zTvIR^Rv<5BEhWSi(dM&S+dZ2Ur1(sk{&V)*njr1Et|a-n#X|= zDx-f^^oMacz8k4Jo2Zk(2%LHCF_+I5Etd0HHqDu!1X`-A!xt)I03&h9!IytvL7zgi z!%A4ov*>XRk zCM`#Zi;oNctk-0{Bf*?B(DOl-=fmXL#7J}?Ez(%h=@whMn$X)OxeU>^LT912H+;{A zaq!ib3U#6R4Yj?w1fjX2K@Na|m-xfqrCJKlq;2+rNMqkI0Oc%tS&E*+@MC(QGRoG) z_-Ke)w9lwo)eO|}y@!mz&ENQO`Va6=^i4>~_2qb1a`70_HP5yL8tt}cq~xsK1!6azyY{65f_=X>Qi9_a%e(SWuC7?ByR z`7+sk;a*m-;rA;aP=l;sY_8?bJCL7X_EZCKSpsLg6t^biGIcONzM!{w)(er3gCq?2};!(8Ems(z@3vRIVbZ8XLN zFl|TMFLx;dK2@N-KCMu_G%2lAlgC2AIl#U~+=$wJ)(+Hk-XQAq4&bha`q|)Xx+wvR zS!R8mTpaMcPejdnBUl!=!^*P`@&NgYoqdj%cs4h%_Q8RIq~~oIcO7Yd(~np^MH^F~ z;29^ywirtqN4w;+3!?qSUU8tm)$qgDcS$$y*1)gC6p3q-pdBB~0aN?2kn1Gc5~!q| zrc)HShGZ;@ir6XdiR2JO3I?}Tz~X=#Aj^7o)9M!S+Zuwe032;`wZ&0s!+(BuMG7dv zWBm;9ctA5MmzRU~+2;nIAe{hm$heg(r^M!M%tf5z#!+jh8M)15Gx9Ag#Bzc!6h~4nmvUnli*H3XVP0d;z1gpqTP(}*d?CWd+R=d0W1&TR-XX0JzV`)>pb{VUpT}?8!9o-PQEjj6R5voNK>vEas z;RE*iBYbW}8X{;whflpZq0wVlDDZHQ2-YS5a*mI15FmHe=PjC>h=Ed4~gQrNkEv7{)`W3Jo6ss1$aAohzDNQ zlm_q2yNgjcL5Mc|BzIR4yovYfND!DFyA5pSWFJ8__LtE*BS*W-OG`vSS z14<~zd*QptL0(v?u9wV``tWgX8C|_a27ObrIA;AX>36Y<2|On?sUN%g2!!_Z zs2LMP)+iHsnCY^SZ#b!|lSihG@Mn*^_Tcp-9*kIs)>qz)+_oclgem5%zdlM&BeDn% zI2vINIBY0pCmhbj#Me9zp&ri7pe;A!1db0$AoqeChbHtR4q_h0B`gCM9hRy8V<L zx-Ys|=L&gLvWy(CqO{I8WGsq`!ih%fDcL#_e1iK%(g3BMr|~v>7C=28K_WRlZ=F3m z&SwdcQ`Eej?$M|XS2UM67ctu1_5EoyLDVNVA+9RGo{)ECezDeR!-dVL%Xx42q!RML z=^ex;b{M`4XjuYJ!~31V4C_pJ$GHy|@z*N#WRe_Nk(-Q0wJ9@oGc6UY(}wmu3ySmW zR(SWQt?yWU5wGsNQ&uWS97_WYzq*SeLv1e7prg!pRbO~-3s=~Ns%x{hXS-Pc$mAgY zQUK)j(&@Sja_IUeHv+@WZ_%B@NCEADu?}*;$)NE*rsIn0^@e6SbP{VNSf$Ax9G$=A zS3cvRxDdzAeq`zC{jt|%mB-Dfc9#Uv&v4J-r77jf7V`c;nwTHe?1s9&-4?&B!rL~< zAiVANee;WBH%yJUYzv8&W-{F%Q~9+DKaA%@G(bfvDDphxbkARm<^IH6y*3E^`+@<9 zA9dc7Knbo(eTfqN1odRf)z-KZX_+w{dbq0ct*}J1cc9nZ^qn)Dar4T<&}T4EfZ&$shh91Jvga3M2#=3M_R_J7rf?9 z*Y^@@{lk@zs@)p}qSDEXFKnC=1!kE!nT7OJ_9+|crRuGGlgm#xDI7z2j*z>?zQ{r2 z&%V1RLr5e@6t=_^7cURt;*Z-QY~Ww$uQmR>qP%RzZ9*Mnx&v%q&6Bwun)TGm`gBY`+#XujC76-8)? 
z^n{)F{ca$Sf>t`NHJ)`GdYrtoXJPX^wL5MOBeGsTqwr2|JsjNTJxF3lUuJwf)mT74 z!kJ274pAgg&2B0g`NnnkhpW=(e`Q)BMPG4$B8Br?cW1bO+@tlasrgs4wPiO`)=$e- z+ap3dYWe}VgE;aLGgF4I2|MwE-L_Zyi>At(nkToJ+GmfgDwBe~ju?+W^OoOF_pCQM zy~@u|vKbQKk_CnA!Bw`V5uzK`Fur)fCci&R)hKHdZ`eBTwJoN+*jS>2@3lhrCm~28 zIo#dW>`4CXQYM48=&n12->Q@+9f504w^mgbJHbda2_(X?5FUHlDY16}1;33H|H{Dn z)KVW|>1-f0@4Kq)gE&q+T{n^VKG9#~aU7pFG-6mB-20&8LrFEZW78QyGyH%DzG;GL znQIW(Kert;Ggl0u-{E|tKgYxsUf}?xVc}#^-I=eWJ6J2?lxUrGY+{e-kLlDlcQ!zu z_lEKL9QNy;eh;GAtA!wNQ1DfWmy74m8sQ^^W>E~N@GLcPHIQ10+?foF7KwyGzAuIN zd-&)Ub;;TLIDsV29#W`e2-ByB98|Mi{WNnGX+ez1#r zUICp+upWnra1ad`ej5y5+%1PK)_!zV#IY%g-$zjul=bhJ6;Hhv*0C!dBz#{0i0FHU z1dgiD9(`}c zoBzBEMul4?mYFo%<-d`6qE8VCzdBnx{(V$+z7+;VJr>2YrjsWe9*fHRX%i=!Z9o0 z$7;2^A`;&mytnbpazT!Fc{f~~Rw46ysSz=6EpZ(B$lt_P%RDS1?-vXB~>|YhwtAwY5enA=)7>INCjwekHxlI2y>OX$v{!jdX;DVsj@0jhb|DdiM*l^-zv2Lk;0gj*F&)#Dmj8PA zBB%c>N!MSH;0|8}Bbgnb&HbGLTz{bD*B`q6Y#seO5yH<7{|g#^B|@sZH_id)B=>(> zS}Xl2|N0AvX3SeSjOyF|?F{C={PR2Wem?omfiH7@a^I!?X9fIf&;J)y5nS^{Az*^88+ZQw)2Rqg&Q>%2 z{}RnY+s%J1noBcVRMx!kZ_Vru?d|`~j$cB#Gm@rL0=bX(?@b{<3X>+_f6bZy1LIu6 zcsN_`-@_OQkn6(#;)z^}NCI`H`oDRbKBo`<6@FUqB;Z4~fPDres=oivzW85Un8{M3 zO(NYgVz5ADoymXl^r)Zz#pDf^ZDI-2 zhanN88OX7H=s&Sh*q>#80y047iQh6%@Y@pby~v1b3Mj;4q=cK{?h-TnbW}C@RDuPo zhWsw|IWXf8AlQ5iTtJsuSFhK9IRZr?(U51Vi_jp`F|`B;L?UXr4CXp}s%ET~zP1ug zXYVH(DKdL!aRNx%&54fm7lC^YdWKF^6Txu*gL%QhzTfCk2^!ULuby8lT$5VdU@ofs z{fuB&k|z_5h|O(td4>qY&zXa8h3zT2-MGg%=ZPc|A*sL~>#Fp;@kThh_pn3HYvJTjrd!qld2L6Ca6<`N%*NMd|ws(83_cU<{}g#BitAw zwUdjBk@IbUo*3qr=jRMO4#QW=pE<|VbF`AC0C};!FBc8_fae-=0Ba|K?VT8Kn%jDZ z|4b=}yK3kG*))UMqWzN;&hKvrcSA+bS5nIMzF&C%*r}6p_1W!oT!2APHQ9%Zt4C?G z0v|15j==MC7X$*O&m_;J?5UR{Yy9&r(nP!|A->O3E>j6MsC3jeZmI-Wf%FQH8p6{C z!S`u^kf+MOiDoYYb!JH^TtyPOYa^wo8NDF%JE+b)fi%#yvZjN(4u+60FyhuJ@`uXtrjI74$Q_Jj|2^ZHkqlh-)29NI?4%aHPoA z#i_k$eYVu?Nqjelver(TX20ELgg(>+6a-%$-2Vy$rl@87?d9~<7q9{4;dJZO)n>$r zt+>!M{ z_>E+*o^A6O0Ej3vpPiVwaH0uRmth4Lcx-g@t0UcZtbAyDGa4Nx-PCJ$F@al9DkuJU z#+|XMcKj@nRS@vOzK+kywEgh&G(pcB_9;C$3VF;d_iYZ{AV>rFaT%D1KD`#|^>`_z zrhhcIkRAuewa#O)UpGqM`O3QM#j$+RY-DPO>V8K<&^{1ZEZ~7paR((}EQPEi;F`uC z$AN&*9Xp|cxjzeg>kc~%;E{tlKho=wm{ypzw z7gjCgUGobmcJSpHuv^j0%LU1z3b6-DjK;cZ@R;}48xa$4Kl!Y)Q_oHZjF@>bw*ptg zb0Sobr{!pZzURR&0^UBp7Jo6+^ya(TyYS=o3^SR1W3dU6U0LZfISqA3A4OuUn@|i@ymP)e*2O&d;!%IqmWA1=3Er?KeG;b!nfR&e)B8 z0P}+ijG+5zO{lL2Z0GF1jcu|n7UC{}q^fk&g3n6jQg;1+@}&Pv*~r2}(Q;dcQbTcn zffG??AS+&z(v(&-K6)oE_kH2r+trNwX?HraKRm4G4EN0SxM%cXlJHNLbwOcX&am6N zi)D^6RZ=5MMaol8ctkN{#4u5w)tkh%WRRiel*%IBLDtopQiHs3pXLGuK0$H|m$Dg< z1VUiVuqNPdstP-6e4&Dx+4HJrTvOcscHIfjhbT30>@(^|M57ms>I@{qIW z+*NCgKfe!Bk<4$Q1IexQ6n9o~v>(J3e{L>iE^Xc#uFk}2vpGGb%=hRkIT%N~dOW{G zrBf12sHg`W)(JgNcImS5a;cfH7PT_$de`J=f~P5Xhxs0s%}M1;N=MXHEZdU@%%1l4 zbTPKLKy55VqPvqmA${kA`!gb#dz)o=TvL<5^&3iKV+pjiEA96((lGI;MLqVueenl= zs0=#VfT+8CTRl-Nn)t4S1uCmAO|zVCkUmo*2@Q8K6&3_kL%E&*QQ?ksK0QAb!8W1nD+du zi7^8#QJ@gGzUaEbxe9!HBr}?dal(WPm(awt@qnFF55YsH(ux|l5s&xuItF?MWyD4H za?g)v%@3kxZp&pA)@XoOm@w~aTR*8;GPQEd<~JmVCa0Vkm}5qlf+L4jj9 zM6~pA&4cN-jmZow`0B94Mk-A@c8{oOOV@+cCOIemRp&Dr|it6;|x8yGbY;F5lE5k5;u0DY(T8mFfxN%)3Kgo zJp8G>A8Isri_&n$*W+`Hn7)@YJDT+dEM|nLl>118 z2h>MFa9t$cveMqL>jTh{_3^oYH?xaItewPpA`+7SJkGZaN%1=vBI~dY3`VntUi$K& zl>Voa_NeSCpzI;rqUvjje zw^laf-K%p3sZ42%LIfLdaowtF`HReC=9#XcNWKC_p|9IOzeInvQKqZ?Achq?mQ98x z0~(Ifr7|=aTQaWoeJb#kwH%Z}MD=Ml2VN_vHnFZ!Ekc-uVaMsaU$=I})Q-}3$Da;2 z;o=73p?MiuvS|8ci%e^Ak-sb-(dU@=s)M4mFCv)8)=X4%Q3`KMJmeKcTBIz-K$(w$ zk?AYW%3&UBK51~@N@JM?$yR&2xKaHfa^|ZC0;?0v#g8yHQp=mW&w4v~a0||RO(xz! zQZE-%`W<#|FB4(B^B% zPLf610U0;%x2$COCo<0oi%h&(_cG*A`OY8@?kHSwYu6yskxg8yC_jzOX6lRdsPR3b z>XQBuK1^5h*~@9kmzaNLyW?FdntvvT2@5=oiP*ZZ4M%CHPGB! 
z`4W-B-xt8+?Wi6M%_p$0sVU%O+XhK<`$1wr-W%z<++M-hyn^<6pG)>{hNaZ&x}H#u z4L-D>jrf*5Bi^)A#Kox+MdUTlKKGtvDsE34Zvk zjSWniFFK-Oef5z^%AM7DBR1N?6qRH3j%r%9CaqWdx+Gn_qcvdR9ge?qoahRxd|pqH zx9l6>Jk+d^`CQ0hVkAwJ?!>zTm$l!R2tAZxWxOfJ^p?{*zw-}qDMwAkZcqDkazR)d z#zU8qcIC!rI$)*jZO2>3zXSR{@kA2<{g`o_AN*2L z9}<=_`HO(Fk>5zkSem|5Snx*KQObF>-d^ITY1!PR9->84hE03^W&tiagiC~08muA~ z->AWehS*%+h%{r1p$YC`EQ;}S;d^c&M=~Csn50vK2;DsCml8?&-GG>zMXAK=!jb{7 zmD+aH6Os^8NYIvTaDwd#j3O@Iw8F#^Az$w6Y7k{NI8xtyZ<)!Lvxq|L?Ko1_pG|U` zbwgjfOEPY33G9Oo$fRQ*S{3EjMa-1)1TmxOhY~%POFL|h^>f+lwt^?qRqlVp0Oqql zTfR#qZ;7Iob!%m|0>}V)LH7*`(P~EtUlGiY_8TErwK@_YR0BUdF0{ye9hewB@1c}^ zACI~oqhCU=2^ecW5$()^K|%S*&{i_CL`^U%P`4?7HKA&E3cb6KEq z06vx4{Lb&7zLacJ#ZXC!f-D%mT&DWIcOT&V&ac}z_$RS(Z&f03%)c_3ZUx-CwN zf{KdX{GSrT`MA~ixLquuQQJ?x?iB~3B_Sa>-`@#YS^i-$ zm9R$&u*jimjA-NTgHNDKW@d|+OWKHKN2R}rn5CvHDiJGcTTe&g!84`sX3i2C4FI*; z0BR3v-9I}4Ps-xr6Ok~f-g?6`Oco|*gSv~;L6mGBRdb7n+gA0-T= z`=7rH#*Q0@mpr=IS}E)7lAPHAo=AEy8?kWg&f6_H>g)Eq8M-uReZ(rXW$O&QEYU!l zc5>I9%A^Z50Jv2%0+ax9-(Zl*>z?3DM=;_HhBpzSub~vl)6@Lj_Fqslee{BDic&SU0OmiyLuvn)qRFO55++^e< z-D41eu9Or8J@WM^0N`Cr?K7M<&n|5Sv}JzZ?*H%>m>ZTqQ9^C-F)H*`SbfXF2)k^N z*ObR80Vj?I%R9yr>-T7Ss-Si}P^qS)I@Klrv8?In z?C-dB>*WeR5xOr2h&Q!SDC>1`0wIgl1#y`*-G5N5!71)pKMlaB8!1VMtm%_TY*&Nw zrvwyl;gTjcQGPDZE-^cD+Ni@FRlKj4W)$>0%Pu~C<?CL%^1*M#k4F_1EHZiH_#!IGH64p^ zsyf)JNCn1B)DIhe7qZIc&G3`D5@egO%K^|1Z?AyI0NyED`vQq*KwS^?oi-A*M;2QT6!;yt!XDn;BpM0Ib>N zDE0v0+|k>g_g=9c;a=SVoP$6&i&_()*1)3wIh&&v0~#O1Z>y&Acd=UgPxI_ZKBwqY zOy$v1yny0HR@AuOBEL1`(ixUcPLH+JKiB)Ka-E`DMn%Wro3yCrL}zEkD%;ZABaV1u^+Lh!g_5Dk~a99GK){KB%~ zcfK-+u4H8K`V}T_wb4W21|Z`75s~&d87s+Q}uaU%LWMCu5P@?b`lt`>ofHG>QIl^}1zbn$C9aPwQf;eX=-NLNx@ zM#M>&!}^D7KosuJdaLo%^8KZ%j7h6r|De+7(U-TMDSHF%NX1Gb@Qc)??OHF{p1u-} zZeI_Wt`A~>Yw(pk>|E0yq!4l8v_Lv2i0R?-2(5{}GTemn{}k=8fM# zlkM96lzD{EyW^Jr@M6fD{jjhe~q8**Fk8cB0HTTg9ZInP(Q z&s4lJSVSn~sm6z;OhDG{N=7}D9%A%t3T6+BL;O=$oVcSB%KNqj-@Y%6)&Dqs`}1809BA-_&x7-2=y15Y3PA8&MlV^8F`n;RDW2slbqyU5&NQjr|U>-)zfu zbV;6@pfO=;(+OgG6hvP-e>Us8`%(yP-hvoEi;`X1vdx$uz5q@#c)9-qhzUj#bfYF+ z^?r*91i#Eoh*_~K7#NQN&b=-W;N2Dxrd$BOEuJTxLMo0Jc7J4Yvfdx zQHsmAiEQ6VPU<*z0&nU5GC@#cqA!z|-gmyde?6=4!Wx9iA94jp-n93=GG3Rb@hOff zHhQ5NzT+Npi^T?|&^;}0k(nS-Ra#1wTw{g|sG&deYt{i)Igp_MW}R6q3zW?+KI7c- zK89WjfR2<>ez=qAgLnju#>?=HWrx`6O+K3wSUuY@U{Dw zaf-S6%y~Uv%iep^P>3AY%zQ3+Y2SfUa*P)Z%>|TP;Xu!gU!s0|@}|#v05Fx^-SE;| zEbBtq@z}g`mZ+wz7n81bh6^_d%lWXwb3caoMy7rGsmArJd>H=-kL~wGmKIT8u7@D> zCq`bxj13&tRpU4;d$5}^6$a%)?m_`(-lXrQJ|5eVY5CdPTUNmXJ9i0lywjx}X330U zsf>4a>izm!u3oQC{$`D)c|)v*;lnpxi>AtIkf$h_;lry!kuxgvi8r_v0>}muytSI- z>Lj!ea7t9^i>#2(YhK)bV+1v`G$y$w^O`n6!IY)SFS?uv+aB?BhX8Z3%ZGjb1e<@E zTI$gv;L^e8U)*e`!UQx1)-v)n6Adkq2W2_}zt)pi!gp+0Y(3TAq6x-F2dn`=>Z4+T z3p3Nm5bq_WL}4~*FNq;w_m-+-O|+U27QmTwR-(*uFuv)j=pmgIa!36ouBSPIF#51K z;@nrsO7yV%-NlXi!Af6eA8ggi%&4K(pRK3E|*|EGk z?NdDiSCT9txKj5_-dCv#G*3&=cRAj-y?Aw2E|kUuKnB4R5&aT*u#a|7h&;dp;Eb}D zbPApjqG7vel`O@}-@0!PEwAY38P)tG_UZ(SDCN9%0(=9D%4#dKpGFolSximIWa7~? 
zA|X#9j>!?VVP4_f{_b`gSRs>xm9Gqgj^aUWVY%WP?1!Qli>nLk1o6HuS#eJ_o{YVVDQE;h+GGw}E*Rn7rn^3g&H-_B=}fgN@w z6lm;tw)KrW)JxC`lxJ)Bh|WmvYWOY~0wnf;t83(BYp!>Wz77GwyR@%8iA^!!M3%Hy zObCp1NnZT$(=lSmF8v&dM^Z4paHuGCsfs)X>sE9AaC5LwzYs9?LG3#yLtfN!QiF_L zZfWhdyFR&qUCK{6sYK1Ijd>H^?1j?;1?nX&FdJOK=6Rvxv-{t;+#l4!@)nv!zW>u9YRGQ^6S#-l_4K1sU_*ywySbGG1vak$Gjlby{OU3M< z?b)0oP8|miCKMXm$=hi*slB9)!eL|3ZJBWwF&l|^{|VN;b{(a_veERWZ^4xP=z)Lw zXDs?M3-)`$y_Cxv_MUKYwXuG5r&#BD{Iy!j6Ep7{!4Is>(9p6Ni_`e&d@d-aMx2tL z6=Kd!zxI zcc!t=Q!x(v#}=z4yciO5X1Co3h*+$IpQz|D#Aeh@(IBxfQD1Cw9q)k5hf=yoCi)Y; zQW;qLg^p>9i#>J6pX2lwLgnJHvl)Ss-k!dv8{p33OG>nt%7VMdNC!Z5r&QF>6?Qid2W>u$b`SdnhvNLN8NQ zsC7uSpA-E=$J%4UF5=~wRuiwby&)H z^!NKnKr0jdg<#S-15CqQnwTm5PoX*Z&vpJdzTv#2&C@4gho~@`qOuA>coI-cV$^OC zp)&>|my+b)RmEI)6#@BV6x{MP@EA8eQLy5~%W>5RS@p$C53$zF%mPJ{;)G6dr>PKr z{I*XDG2=D&SEsMiOs9pVH!jR~7xRtZd@)8G4aI-h5#Re;wbjX1_#`KWfAO|P_3p_; zl#+<)_tgb72Mb}}8HLxpNdj<4d#?XKoLmIJ-EUOh#y~|3G?MvM=K-!;kGK6B*?_Pd zRyjbW3O2p2$7J$P?cB3&^}+EZ6EkolqmzfNRNgyawCPP>yh|ZAul5T5kIq z#4+<`DrWLRusAF+kiB2+*TbnsG0@A0a_=Az!Qq;kNblMo>BIs~i0aTte% z47J(RK(TxwLumxjR6yyMncb(&hkhcUZc|vRtSJU)NR#hlIuBj^0?(yp+(LO_%{dE7 z-G|X>drGJUS1a;ewm;oIk=F$2|4|2Y0ldue%I637KX_}&@^PFWE5NAZ{OTA2>zA49 z3M-Eg6*`j=L7@rBkmb=@6dJ&DadN2PA^L1A{zrMO{t75A*5{YabM3#Kj)-IK<{nTm zGueiOECunu@D0wp2`=&kqph&AoO^{a(f{u4SE>2f2%o}Z!9 z_F}vDUU|!&WoDwx8*g&715m$!k6_orb|E4v1IzfkCyCOa*`EGfxB`I5qUZn4^7czy z9tKe8*-~(SuGH^yKt1F<7AW!jx_g$pR-o{FNl8`VmViwC#Wd{u0iN2b4*(^U)MtZ{ zVwoh(?aikl`z1KhVweNBcgc|2LiG69{9rrIOWADSAKC_`u{~$1v@as)^BNe|xecEu zj7Dg%y=NJioSEI%rbh3*zI$U|xx5e~IOuFMND|>Au$ceuKwx6LENM~gQ6=RRcl+yP zm64DZI%LA2XAppD{nlyPCzSueufDMTP&;DkF;r~=gBHfZ20y2BU}yqH0s&-0Y$hen zc%%%^CKMGw6Wv71O3`9kW|htg4`fjgv{`9%7ug zvAEkQgySl8`1g6F=V2api)ke|6i0Kj4#H5wEy*9-i6ftfnI=AsPSHoCl`tVoC$Zz| zOHQ+5IgIy-@6=q%>Z&b!5w1u>UAd!%Ls(CjJpFiZg~QMB9i;q>t`D_DwztAs(*7ub z2>9Mg4QO`>cHL)YW{vnFZuL~b$A^~p{KBrcA|>sly-E4#_-n4WTDud6J+W0y`Ie&` zKbS`a61p!|zUl(65`>L_a7w}5bht=9s?&+)EjqPMNG&XHhQ^tNUE|cAluE}jU0;~} z8n|&pATHE%Odu5r{bPOtF={w-rdfI9_iPFMS&l8*+gtbw-whPno>gHueZ_Wb=7W%2 zTSwAtQZj%6ITv{OgP#EOuO0j(D{z-3Dt^`(HE=kM-UTL8C0hHTb_`i%!?U}3F=1bm zaKaa-r+>?^uSu(v=q&*Kb8MJnUuaZsg$A4PW#+hlt7cRsHET(I{YofP3P#dGq0~P0 z$VxJ8L-85=CmL{tkKsnX`-9fq&A;-nedkXOyRkvv#(|;&k@%QJo};W6$zT=el-j>K z^Ay!UW&Du&jnbn(8=4!8MRGhEvah2rn4vwd0{eCBu=JA25Rhem8$qNNEF^9)T9I!( z*0GC(bM%Q90?4fOYHw)v>HqpnWMfi;I~*Q|34)<xge}ep7U+t@vfjocB>rJTJbDj%@0WXk5^^u1#iqa52 z*4TU}KlTr>J^XxO)&%ePg%k%RSVm6`{oRWGho;{Syj%d(iewuFpwl7UZOHhVGfb3v z77s!CMddV>-!pr}*tWcp6jwg_YqbhhH2pxoj#TKXY%wyEc1;#(602OfhU~8 z3ckVI!Dz_gg3%4#5Uk(+flm0IHuXh5Xf`-e)(78Pq$ zx7hkft00I<`O#U}W0Nq}(QGFxVc|O(zPxm^3u$tEYG@Z2lt>AlL-rfiLHDmP(;tO7 zq}O#Bw@eCORQ`Jk?{O23fP^SGI5dldd8*g`2an43=cmsC%pqSg=#J*omj~7GMeF64aLftGKxtFw zLE^EH|4V|MjzR^VhVrAuK5hM)K%)`bO|-`y^KbDtA2t175W2dm$G=7o zeVH(f-Ql0!n*Ssq-Wf@mjCmguq@6WJ$68Zh!nsUNdA&Fnwp@EL$Hxt6LJ}nfnC-yC zDq=g9aDEI{r#Ok2!|?*FsD_cSD zGED{Mg_gh)gx4egMd%Y>`(|#=OLl?O zHej>iY=yEOaA$Mc1EK#Y?x?T~84KcmbmSr{{dg4{j0H2;M%FWJ7Zw&@_EXu_{UHkC zg#Rq>No=>7P-Pqu1G@=j^Fh49%nx-^ljxYJZe;fB9{v2&dS_JX)Slc5c5+3cQ->S- zW5^II{A^0?Rr$-cAru`GSxgx#duGb!F0#iG*AUNUSdHq^W66hm#c1QI`4I}o@tWx5 zG_;8Hj|m(MjEBX>0yMb$;NekL0&||E-WkGY6$E%?(S$K^lBZ99W5~Cp zN9d}p3J-q}oaO*J+G|L6b>A!2#?yEKwVz92mIY3z`tI+k!f_eD{w~4{5bW;cK z_%G?f&y1AS&l3H_dx8zE_r!6}bq=ln>iH!qqak3np1S}K{r`E7b+)OY_@UB0c^BfA z>YMMaj0BbU`u}C@0!a@Qfpo8+a*7gCqQEvUo0X5-!uOOvNe>YFziPf6Xm!Xv zIIWEs7)nEpq<~@QOJEJf%e6WMO#}6q^4)`|l&21T#tvu6z|pA>Jj%gxegysOTvT(8 zRb%=c^EQFurz`q8?xpzZ@y5z8n?tU;lF}FI*;c10(sA}(lqy-M)-+97!?nedblLD(0S^Pm&&=2+@KkW-Gy~E}E!v9gW7{v8K=bNeINnz66=Errh 
zlPj7tCDzbyT4M_7C;p$z>~d|iIN$~L>z$*9Isx}mtTh&&FCYNnlWB25LxjM3Dn3_{iSw}S zo>W#5yquh;q2vlvD8^uJK=_g>AQ48Nc~$UtoNY1fkp|E?PF8!QaH*6>jP* zEqfvO+A$dQ8}FjV)*|Obe-d{&=}MffV?)OyLIV|XQ;i^giSByu!ysAlG+7IY<_Aq(ufAN3iyHlcDdxF;ZXnXVI0x<;w zw^+bmw|-P-OMf1ViP6H&Mv9vMhWN2+M6LT2Vl3*SDKP1QFL6sXvNwmGZYgDLA?odwze|)*De!6!S?1X^&7k<^_zVMqx8;s-c=n`YgVm}m< zIcn%TLw4W0=!x1VH}zM*$-rb-jN6Sca*I9ZMBiIM(54>-ja0w-HVXPJ(2S)};4p|#x)_u*q zkMK_Ty!;!@g4pjyaXE}a8xoR1f<&FLMVGWzPtxdy?M=dl-rh|mY{6fYq2IY(z z33+r%yZ8qBXUJ}2(R&;y-%>5Abk(!J{FIZDiowYCm9QW|d&H@y=(+aqsQYm~vigz@ z!%lK+T(dY6%7~u!%o68+{Kcyfsy1e!7z%^JFVm}#4nFQA%PS#En>fqConr>uJ?*Fc z!l)R@rL8MWCE0({U?1NI8~B7!K|Da_L7ug|@ZkTA*Dw^#2_frpgT-i|wwMB^i`xI< z`na1d$pei3%rU58aUvP=c8iP-YJtnEldt@QK+C^{;gF51P=IzpXFcV=0%b27tMnayYWotxEHM3i+?6B!B@| zpY#7$C13pPH4c2c7t`OzDn|zT9|2K9;kn+S^|EUK5*BMzUSg%+y4I@jMC9cXW&dow zdzK8-!=hqs(0s6ho?6;i5lF5C`L9dKtMDjtvQo=-__Fd$Huh9`?#o&owDoT-K*B;Q zX2WKh8{ccEpX%uro}THEp;VPYB}ntI!;?tpj97yYmp@mUZKI9}ib$^i&PcE|w<3mq z%odj0dKo-B=JX)srC0-2Kdtdfn%cryHu5qH6E;{*_aKZ&&@Krx$y`kKUOcfO0;P99 z^IW{4g8vR?G38IGx$lX!7V_=hTNmU>qm%#ksSp`0b`0+7pvZgR{(13)UA2?7l}*Od z!0cEa6?ClQc8<)D!htuy0DH(pXgxJ%2xt%AuT%uX5fi;Wr;!YdPRd0$#w>nknjq%2 zqprEQ_Xk#IHLO8MWb@genIB>k{R>aWt zH8}F0P5TLuj{wdKM`{i=&7A*;62Dt5cBor-mD3;cX8`~{)@$Bs`Qq@;zN~Z%mbO{? zLW15H{wlD4RdC?-QO<9sBF6B$L>(N=VBqNH%3OF&4x8QlU zL#D2|1jnlzj31vz&dw8QxSIt&>8Nd7vfZV%ded#>3}T?7hNmHAZ$9K`Pf4V@IhR;< zJiU`jU-d#qAmS0Xu>A??H|%}7opnXv{v0gzG^11>j8!)Hn3Bd49^cG_th`PR4F3kN z<4823+?{)%8{U(|**I+&ON-m#I4YfWxP8Bhwq6!Fv!%nFskQ4a8FftqCB+kN4yB^I zOmKRKf!DDx$5828ql-Kd0LqXkfw*Y@>RsL*JG+6_z)3;sy~u+eQ)7Ni3By${1QGBm4?V=X3>_-qh*YlHQh+}SU)5gSbk%lj zly)2Pe+O0;IMUw^JbPjT7CgwTYy#PmGd3)x3@3R6pilVjU(oqP@nLBf+1ZUkuO?Je zc#%3`p+JO4apaAz88#WpBe0aKRc&+Jwn~C%lx~5{y1feRAEtmhoCmX^&U5p@*H`<> zwai^dYk&z^+^&GiXFI=IsN#F$B_U6}5sV_zjjVxrzA2VBm+AaPGI4ie>Q`?*!}fs) z>rJ=ll~)sL9?T9~0f9N@kQ1@!lm-fQdAPX9tzp&t1XQc6#N}$N=1M4G^OX8Fhwhr) zb|Xbe?bA&lGny7(J6-7t!zi$|u#MR8+0r8!2#I_~96Cf8o^WVSz%hwTR}Jp)LBO@^rbk^{rL?LiT&!@FStxY4E^h`CQhsOSwni z&BLH*ap%3Zl&qA5sqIftXlq~edJme~yWCO%WSl=?P)PNxJq?E%=AxI0I6_%&!PLFh z=28jZag|tIiBzwe06|2r-tg2PWdA}D?L~l_(h8`=T(~VKNkqa`mG=HYKq{j48uy%z zWUQ-iX2Sz{GcN^M3b9hg-~H?Os2J9U_az65ZXD%z7gXg-UgrqX zp9&^Vu%-6DnpQ?|lRw5;R{`DzfDsl=t4b3v0

|M*`rAtr5>|G^a~xfaGXx?m;_g6>k2O+ET7yyWdjU3&J^% zmPdQWymrt@Obg8v30E>c6_0x%@QwMpDgxXLLYZpvkhrBeXF)i?ge?!o+R#uU`qJ56 zAScY>xDogl!}H+;IK6!eh9A-)mroSU$y>s_)DU=DUh_boxz~b|Z-jY4=~^FjnE=~w z!mOy=J7iSdt!pf%dkl@7r=aHLL~_9K2GES;sDo zmCr`*^AoH~NqaQ}YWj)RE_|ahOe^+2@NQo_L3K{K6;GttOKO2GV_RD?PgmNdc&&qT zv&d1DaPT#{mhY|8YMp-}#^zuAh>_ny;vV`axyX@ehc)oO(wa|BA+|rmNs_l%SQUCL zRH$*B1aU!~J;P(vX{CnCk>@wS`@{Nn1!q5RD8X3;Kal>+oLn=rCyxZ#L3a=SZVZ@% z#fPuq z*N^m+Qo>B3Y)6@Sr{|Ur@!6B>B78}Gh65fAB1>Ar5ij!RLLd7;i;SUPd66ObTQCBJ z3NAbL${?5rmejCMj2oi@+o?sQjYnzmftB%rcW)Rb|f~MTtX{uDe41s`6I#SH5TNunl^FPyxXjvFwdcNwh zm<04tL=3m%%fb%t-o(j!<0)+wJuNQ|j0r ztOU{;Yff7114#xiJ8+dw_#s^{0l13n95=FRQ|kMV-|DX~)kQwzE57#d85*8JtBaA% zxwTyCrem zTV9beb^_u1#BpvQUYYZ+8%Nt|x*8O^`L7$N>+7L0qg5fyrKx!UtaT8S8|nP0SXrP= z`}fNKj~b{uCX|KmLj&CwzhJrFg%Qpjj#u9p4705Rffj>!7l<{g>IMNsW+3qi60cd` zQW8fs_?K!!Wb73EM3n1iNZa5K%$q-lZZb@w=-R`~%CI@+QFvja?7QW1H&7+7xaVs! zOr~mkV@9m1V=AlDJe6Jau&;4K&C6eE&lISm?2A0*tvDCNz6+Wz%cwvpy@YhnNv7&L zfOcPFE0)^xt6uu_IOjtELqZ2)y5#IS5Cn3P(hv)irOIj@Xab;yQ^=cD;FZsr5t$w- zxu}U@DT(k2WvR$_U`jpn#GS`K?t7Y8&?L9}`zs-DNT?slI}@)>6_QtEXCQF$c!hv$ zk#y!sWpMrBMnG#G#T~fbY5T zglZ-E$o=n-tP+RSFkNegA&eiZdj$L*uK(&OfFcxXj$~{od?dzjU3>dJO#`I zptIJBTxGM8X~!qKdtL_i>+4Dfw@I;6Vq2a#5W-5DtoOQLvKmCN;w6*xe({wV{^BaE z`E}u94qW!0lT1+&|GWqc(l)&l@tdOTlVp8a4$PBqUZTNH(-3*4YOO&H846=c-f6h@ z!2O_kA-hB;{0Ehv(SMddlOXPayfzR;-PnFL?claPt&S;ps!kJJ z*4IWXSvTBW>VPSKKpST#(wb+sRCaCl5*KipY1ehQt&H9D#3XA^sB{cn`-Mq~nWG{9 zIW<^TpNuQGzWx1?{Rjhk@JiG=wCv4^#VsoV*7m;~#)q-)v zPs^WtOupz%6Rg!wQ>pqot^@@V!nByq`O_FHzO^ADuZyW4HC}&L;4VIS_cj&_>Iy(< zE0BRGgZ3(v2xzp>#r2c0W>S({z(C6 z!-uB?51Zu`ri_c1+~v3Kfea7!vm^B-o_99;x)+iSZ9q=RiTrx9|K`S_k}#uBkuq&Y zR61%5Ty*;fnT`i9hYl!Y9z~0@Fp6S^MWw zd+%?{z5Sj!I#1oy#bf>2*xq^O?%_mqPXRXWdUkz4Lz8>lu(s2?mywkE7K(HIY&(za z7&CY*thPr`SLRn4rlA3w3ZJ+2ps2+-qi75ZJ^U&3k77}e6&q|P*g7f(ktk~z%&-3a zj*~C$vx%Qf9XBoo;?LdU+rK-JemHFq25(3v*|D6H!FWozE&1;A1V_JKOrzsvArn~| z;>XpPk?2I;XW%V&tD7txbOlR)l7U#ytI_XQoNgn&UN};#@>2aeMO?}Zg(dWYHE5jv zND45f1)lHYUS3}TQo~MuvJzQ`py>&r(ZX8`NKdr9>0PgjV+!zE8^Xz zMLHbI09HD#>hSRQz6LBuG)6JpZ)(U=Vo@Czi zC*sVLy~AE)5ED8=gDY_)M`~K@kh+f_@IiI9-MH#cWRxXu)b826DjueWWLh$yF-Ie< z1*{WaK^3=mrgC6$;U$3E-on1W!JY{EfaTLkQIWW&LjoOdZ#1T5aEvfcFJBIv&SbyE($nRhVbd*rwFfUX$mUGft6k z=U@j$atz<1AeQ@pPEEj1_w+6E64)5VTN{$WEjzV@bRqd^M4QF3#E;B6u)rLY#iTS* zRhlX~VEw3;nKf>Wz~K=Iu7@Ka3{06$n(JA~U+4Kkrm5Jc788WcAWCwz6S|AC69wL{O; zgOH>4<(Ja6a2$tse+*dXQ(elk%gD2efKt`wg_|p!_25#9{wx$3lgDv9i8+B*jPm!2 z2L4Q{X95G($jjkJwQgOzOWj3^6%p}HjEvj)2UPyQa^_2a@;MG4%aC=uHf21A{Pfn~ z2G&2}P`aa028YN6{eE+OG;BOVUW|dW_5&Yx z@3f{w1fp_LZ^{^~j)k=@_~Y(l&-GePy!P)lCGR(1=qrS6?rgAebOtvj(=ccl=!p!y zFq=Sto%bd5%twVa^_(M7Z>S$2n!hE-c zg!tHVnH~Xxn5Xf+?D9S%Wju>s`Kjp<3u}@+5u<03sxqtzuY`>PSKt0S};U~=~&f@;;XNloqm5~Z(-Jy7pI z1oKUN|JO_V>7g?7R2BAr^w~$}t`LI}xKw}+9wvzu^gMXA05+chEQ}apn?M`Dr5pLM zRCPZVb=wMgNOC zla!u`lyd$I$-vL=dTR_#g(ZlOZz!>~45<z!oyZ=EQgJY#?Q-ehk{&)wi!tPOJP1nK1 z^}J17T~ogpz6{bp2u5R*OPE;c`HhrI(GVN$VkFwSKql@+=9E4E&dT_>E~cP>kV`ZC z*?>H8_7EB(?~Ef;Le=y`S|o~GW#1@dBH>a2?y;W@1`!7tEiLX0XQsr~ z5b5v)hIpa!MqFZJK8T@aSL6WFo^je2udFnMyDxIVfp>6qMdMwC8DOMHs%%C|MrPmjIU>G1gKY^~zMJ|0Ky(n~ty`$SH>-bR8U^yBa2kMTAk7YQ;Vi6T#u%NZT@JWb!kx{)=40Fv^A`pgp1Gjo=EG?)Zu~MeyOnBc%Jw)N{ z5nE0u+Q}j8swI^FPPvCX@$CaQRaVNcKF!w#wZ;#PWQD_vP2D5cYq zZ+m8B=GXb4ct4*sm#5 zI<3dF!MiZtdT>V|iG}ZfHq|=t|!;KAZbmRshw9_-^?ANL01ghU}iQ!fvFkLEUIgZ$%9rXtF3FRmJ)y{J< ztug=QO*fFBdgY(<87FL?o@7^ljC@=*e5$@Q z-xCsAk%--t_5?a1PqL0u*6s*(>J&oCO0b_D&*U-B$PxZ~%nt90V0W4D;u) z0;VC(QPN@#DxIe%xjk|8s=+HN#VRaiKMZfkSss4v33YarkFMasXJBJ%Z(QCu{Isj( zh?>K}L`-QyA$#SQf7|}Bq9YpCUljPY8$=EJ7;H~PkM-!wlFnPI8hy;j-~zfK&D5Vj 
zZjkIJl?^C&DKwIT$=qUR^jvuO5d90>puapGCR36JO+(Zr`>eJ4o1h+E8;SE`5<4l8 zE}5LrEIR0wcy^7=mTYM!y0Z2pBrT0h7S^44p!`trCM4?lHfVZG#KokJ9Ex9}`-~f$ z?o|{TuD5#vUXj7gzgqXzxcxHX3U?5Q2TDtYmTan^7HRNwk-Cv;-E9#Uubo?N<;Gf6 zMY`sKA&PaRMkChb(PhMwxjnEGrqYvrG&5XdtK59?u>uc#fF%6U0@5{g0?77d53`h$nxVx9*nII}cr< zGM6yXFT<1^w@hGf%@EDEAoxf1<|ECanG;n>PN7=5D8qqh@TM0TPC^V9(aa|%`(`qD zE;eq!6j80yak+*{j*dV#S6S92G@rPK{&XqanUjozDg#?2&;Hg#I zI~-v-PieBry6^-~Rg*#Aj}7je5t}*9fZ;a1AZS(n1EtPhJF+VL&lnKp=`<6Vu#ke# zhi-Ho`rD;-IIY9!+EI>Wj*g@qygvkKQ$(InG(1*YPSeI0?nebH@jD{VW+7;V!7^Q> zR%Vcv(|LL#RKhXMIJ~t6NL6wPz7xq-BNzlYK_oZPji&^Io+qX$D_)A)OH^I9fXLDk z9MPV5;(E0`oOs?fU!V!!y>w@(_^S-|KwJ&aa+vgT`fNzYmk0 zWnghXb`!G%l;p?M?iT{!N2+x5ZW~Z%dGsK}XNXAvCOq6>k#p?kzEC9cjp;0Z(?dsx zq{0Gv9H`AR&hV{DR*$5A=9{Jf_X%uM6uAX&lYsV!G|7P^Zwhf=?(Mm}5xM>w>6l&6 zH&7^t_H%pq>}GsSa%&{%+{hVu)0i|R$?j~*M<4agV8l^*`IKY%Lb`^Kesk^tBvubU znzhFZs=nuX_t#XpUa=#>atPDobC~pM}q5e@1fKTNP^3XV@suv0RXQoZ91qHe_2{9iTMm z8uP}Y>q;6L@#>yiyvxuHqAb)GrYkj?=1AcVN<=C8EXX9Vt;m;Z;im=S#*(HR?G#bF ziH*%A{0Zm23!Ck>!-(gYin}BE$jSM>%Bd_aq79CW!vSYm;>r)K&;aa-0{WKzZr^wK zDbb{22pS^aI*VgCEdYnx8${Ua?Pas#uFD}Cpa-g**w5Cp=H!_ zbjJZ)e;USF%O8qO%M;&FdWhFrpo^Ri~pSFM=yK-YGg1;iPl=`ZLd^ zuiUOf^j6sHvSLCa%+GB9$c(a$p6>`^rkS(x5X5I7=+yvi7^yB&7NC)_p`F0x6?a8e z&hP7k>4iOKf;D$Oks&AI*i4#_(zAcugGgQ?1`Q7HNYZD;BchfXEDm7hNQMImo@sg5 zABht!cHlh_X>V2^mI5Q;tDf>7Iod@J#Kc)lduF zG3=_5s4ZBEOrAAR2U`7;k)GsC(6-ez+A62hC zsM>in={c==0Jm2kFh4`TTRrh#ljJTZ{Vfr55-`l;SY+&FD)Fh46y|>newqGLnu|(s z%!B0mS7cO7w1mD5Uqyu!J*B~DIf(@393Jg1qA`{mS^o>VPqrT4kL=mIFR9^@aTVw8 zpc!|0hDsm2jPvQik;OcxTnTxu;_5mFBNybirIdmu99WBL>k|adCEC)4ITUveH8+}? zFS9#POw+XIs3x5cGc;B>ml7#mborQigE0$g4-w#5B#YG{{oj5__-o9P+YiYg#YcH@ zU{@q4Qy+@WUp8CMPm=OJ44i98hBUcu20UCIBeLLx)01hn^w^si6?2QnYo0Zmub28= zRl5CIOzKf=DXnwRckamE3PQtK#bSJwrGcZ>yvhW|f-YOHy?HA;ThnYKpFqE~oA0Mk zbO5fIRv$3wFCUo;pMv<)i@Qlm1@@Q1>_S4{)Ib2R`$XNX$XNwS@Ev!%)f9;JE&I|gg z3#seC5ks8qkvGdT=f~CP8YyNq4mmYjXi^jVVE!loOmVWjweN5{Ky{noOh7)7mqF(5 zn_v)(CsLw6;OdGJMwXsRnYa`a$jH$oV55WzOhg%%%pP#EYQw*{o~B^27Cb90&!#KI zx^qTXU)7clwqTk_gB}b#j|Y?a7Bbk_Ag6Y;rKlUhxybUuFAJ7Bcv#Xjkw00(K~W(n zK^df^7=2KLbE#w2T&~pNKp$=0O?pHp>VS`;^|aW6kiq25PH`-2OkI{FZ?fEH=mG<< ztyU$4nq$V-(G(9x5=50y3u)eb=xb$dGcdY}>?h&RbK&V?xyRZXn!k+K8vj7cKdD!HWA1M-Js^kb| zHIibc^wu7dR>p-I$f$T~rKf^Ax^* z@DQfo!)XKaiy?LrKN!4o#iGN49Kw)DMkI;0@FxH6qqzoir7$Z2OV}Cpx_Z>WM;-#B zy7I!E@i*x=_Z%U2qeZFd^fGVCrmDMGplDf*xsTfXoq$F0AUQ`%L{Wh9ufQ1Qb^Q-t z4r6EBI3;jx$otd61UqS=-r~}Syc(K)Z1(6SwD)Nio7(z zS_&RIC-dDFMDs{IHLWhP&?!`Og&vBcj!!C`_;_dJ5+CMlm~+iPio$(kteYf<=RuXB zWvh(H`9^2s$Ywi~+07~iHz(9$QK5U%ARQ*@kQuTDrP3Z`uozC3&geINwkeUiAFhtsfHP_ije!4b$n)*?w!WJ$DNjs};$i!5S|Ex-+`EC~dh_m`LZ&NWy1Xw#7_wTGfD%l378vfk*~cbA4hOAbe7k|Nn8zec18N7LJo7!^-Y z4SM##d>67qwTj09=6rn_6d1@f&5b2CV57Um&9T|_R2^qckh+Gi-Slxl=6u@DlSB1U zTYvbdQ>O=gz5HDEcrj6H%1KB20^RzJUDO~7U6cqHe~}vMEo@za5em$IFoBy%Wuc8D zbe5N`gp*=Cun&=k>uL2;T?9f$I>qJQF#XSk_{m~wW_+iRO`y=>mc<9O zT9}1d)0D2_Y^@mBQMPhqxw}eR??x!^GC8N~*LmVt9cxN`^8fET^GkhtXrU zN(f4j`!8dd5F|yh8z<^e*qMu(R2ldDq3qKP8vW%9a{OC21&2?R*jeaIvEhU~|Dy9Y z)6i!=k4K7o3EN*#i9y2@Z8(b6W`mB>j1YR!dHyZSammn4Xww$58VI?gQ%yG!%LCf^ z49U&WIQ}8ARl>PhXy&pb)4=`vi?};7^_dr)p!)-@R1B@~q|!$?yA7tpLhM^dT}!uN zbL{s|5rqt3fOcEn7mLXpb9x%ZhYZ;?TAsq{<>ZJ@OBFe|55bqW#V}vDP!qn#x|FCD zNdjBy_+pkbh|E}vHN;J>^#0nE-m6?)$OcI$|4Ro#nbFg(biyzJx#rIr7Nbfjfj-CnksPG_Kj~)1hvYiL5faR_#YI@n+)~)gX19|VkVX9O)EjX))+6f0 z@ZNJ>CxJ*mHjyZbnTLmwFmio}rQfiiaqSY14#cQMM_C$}hZn`iV#I+chFu^XF@gRt z3gx(j)*5>WrC(X$bdW%TNwl3N;3qut zo=4{X6Ux3$;=tBw?up=Vruvo9UlFFjAK$y;<9^-BP1c5#K%Crcy23jQst90g$r4D= zuPyL`?Wt!NUO6Yq$)7lOJgS8DB(=PHBVu4M*eOWFh!a!pra`+qpE;}wdqt_l 
diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt
new file mode 100755
index 0000000000..a4eaeea43f
--- /dev/null
+++ b/docs/examples/te_gemma/requirements.txt
@@ -0,0 +1,4 @@
transformers==4.55.0
accelerate==1.10.0
datasets==4.0.0
sentencepiece==0.2.1
diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py
new file mode 100755
index 0000000000..6285fea1a9
--- /dev/null
+++ b/docs/examples/te_gemma/te_gemma.py
@@ -0,0 +1,703 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import contextmanager

from typing import Optional
from functools import partial
from collections import OrderedDict

import torch
from torch.amp import autocast

import transformer_engine as te
from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
import transformers
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel

import torch.nn.functional as F

"""
Top level description of the classes used in the tutorial from this file.
----------------------------------------------------------------------

HuggingFace Gemma Model implementation hierarchy:
----------------------------------
GemmaDecoderLayer:
├── self_attn:
│ ├── norm: (nn.LayerNorm)
│ ├── qkv_proj: (nn.Linear)
│ ├── attention: (SDPA, FlashAttention, etc.)
+│ └── o_proj: (nn.Linear) +├── ffn: +│ ├── norm: (nn.LayerNorm) +│ ├── gate_proj: (nn.Linear) +│ ├── up_proj: (nn.Linear) +│ └── down_proj: (nn.Linear) + +GemmaModel: +├── embed_tokens : Token embedding layer +├── layers : GemmaDecoderLayer × N +├── norm : GemmaRMSNorm +└── rotary_emb : GemmaRotaryEmbedding + +GemmaForCausalLM: +├── model : instance of GemmaModel +├── lm_head : (nn.Linear) hidden states to vocabulary logits for generation +└── generate : generate method (input prompt -> GemmaForCausalLM -> next tokens) + +How `generate()` works in HF's GemmaForCausalLM: + 1. prefill (input prompt -> model -> lm_head -> logits -> next token) + 2. loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: Notice how "prefill" and "loop until next tokens" are just part of the `generate()` method. + This is a common pattern in HF models. + + +TransformerEngine's Gemma Model Hierarchy: +---------------------------------------- +HF's `GemmaDecoderLayer` is monkey-patched with `TEGemmaDecoderLayer` before `GemmaForCausalLM` is initialized. This way, +while the model is downloaded from HuggingFace and most of the code runs from HF's `GemmaForCausalLM`, the underlying +blocks of "transformer layer" are actually from TransformerEngine. + +TEGemmaDecoderLayer (inherits from te.TransformerLayer): +├── te.MultiHeadAttention: +│ ├── linear_qkv: (te.LayerNormLinear) +│ ├── attention: (te.DotProductAttention) +│ └── out_proj: (te.LayerNormLinear) +├── te.LayerNormMLP: +│ ├── fc1: (te.LayerNormLinear) +│ ├── fc2: (te.Linear) +│ └── activation: (te.GeGLU) + +To be able to use `model.generate()`, an entry point is needed. `TEGemmaForCausalLM` is the entry point which +subclasses HF's `GemmaForCausalLM` and adds a few attributes and methods. + +TEGemmaForCausalLM (inherits from HF's GemmaForCausalLM) +├─ model : inherited from HF's GemmaForCausalLM but with monkey-patched TEGemmaDecoderLayer × N +├─ lm_head : directly inherited from HF's GemmaForCausalLM +├─ te_rope_emb : RotaryPositionEmbedding (reusing the same for all layers for CUDA graphs compatibility) +├─ hidden_states_buffer : shape [b, max_ctx, h] (static) +├─ generation_buffer : shape [b, 1, h] (view of `hidden_states_buffer`) (static) +├─ inference_params : TransformerEngine KV cache +├─ model_context_phase : GemmaModelWrapper → uses (model, lm_head, inference_params) for full-sequence prefill +├─ model_generation_phase : GemmaGenerationWrapper → uses (model, lm_head, inference_params) for single-token decode +└─ generate : generate method (input prompt -> TEGemmaForCausalLM -> next tokens) + +Notice how "prefill" and "loop until next tokens" are specialized to wrapper subroutines - "model_context_phase" and +"model_generation_phase" respectively which makes it easier to use CUDA Graphs. Just one more abstraction is needed: + +TEGemmaForCausalLMCudaGraphs (inherits from TEGemmaForCausalLM) +├─ model : unchanged (HF's GemmaModel with monkey-patched TEGemmaDecoderLayer × N) +├─ lm_head : unchanged +├─ hidden_states_buffer : unchanged +├─ generation_buffer : unchanged +├─ inference_params : unchanged +├─ record : utility function to record the graphed callable +├─ model_context_phase : GraphedCallable(for Context/prefill) replaced by `record` +├─ model_generation_phase : GraphedCallable(for Generation) replaced by `record` +└─ generate : unchanged + +How `generate()` works in TEGemmaForCausalLM/TEGemmaForCausalLMCudaGraphs: + 1. 
model_context_phase (input prompt -> model -> lm_head -> logits -> next token) + 2. model_generation_phase: + - loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: In the tutorial, `record` is called when initializing the model. + +Additional notes and clarifications +----------------------------------- +- Wrappers, not submodules: + `model_context_phase` and `model_generation_phase` are convenience wrappers over the same + `model` (GemmaModel) and `lm_head`. They own no parameters; they standardize buffer usage, + masks (context uses "padding_causal", generation uses "padding"), rotary embeddings, and + KV-cache (`InferenceParams`) flow for TE-optimized inference. + +- Buffer relationship: + `hidden_states_buffer` has shape [b, max_ctx, h]. `generation_buffer` is a contiguous view + of size [b, 1, h] carved from its start to avoid non-contiguous indexing. Generation updates + `generation_buffer` in-place with next-token embeddings. + +- Padding policy: + Inputs may arrive left-padded (HF-style). Before TE execution, padding is shifted to the end + to match TE attention mask expectations and to keep shapes contiguous for capture/replay. + +- CUDA Graphs specifics: + `record()` captures two separate callables (context/prefill and generation) with fixed shapes and + stable pointers, then replaces the wrappers with these GraphedCallables. Under graphs, the + functional behavior is identical; only allocation/pointer churn and CPU overhead are removed. +""" + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. + + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + + self.gemma_config = config + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.gemma_config.head_dim, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, + ) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + + # filter out HF specific args + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + + rope_emb = kwargs.pop("rope_emb", None) + + # Return tuple to be compatible with HF. + return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),) + + +class GemmaModelWrapper(torch.nn.Module): + """ + Encapsulates the HuggingFace GemmaModel class as a wrapper whose + forward pass is compatible with CUDA Graphs. 
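
    The wrapper adds no new parameters: it reuses the wrapped model's decoder
    layers, final norm and the `lm_head` passed to it. All activations are
    written in-place into the caller-provided `hidden_states` buffer, so a
    recorded CUDA graph can be replayed without any pointer changes.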
+ """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + rope_emb=rope_emb, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + + # This is not needed for generation but is needed for training + # or finetuning. + if self.training: + logits = logits.float() + + return logits + + +class GemmaGenerationWrapper(torch.nn.Module): + """ + Gets token embeddings for a batch of single tokens, runs forward pass, and + returns the batch ofnext tokens. Also compatible with CUDA graphs. Not a + subclass of `GemmaModel` since the model layers are simply reused here. + """ + + def __init__( + self, + model: GemmaModel, + lm_head: torch.nn.Module, + dtype: torch.dtype, + ): + super().__init__() + self.model = model + self.gemma_layers = GemmaModelWrapper(model, dtype, lm_head) + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + logits = self.gemma_layers( + hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb + ) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + + # Fetch the logits for the last token + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + return next_tokens + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Monkey-patches `GemmaDecoderLayer` with the custom `TEGemmaDecoderLayer` + class. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: Gemma model config that HF uses to initialize the model. 
+ """ + + def __init__(self, config: GemmaConfig): + + dtype = torch.bfloat16 + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + + self.config = config + self.to(dtype).cuda() + self.hidden_size = config.hidden_size + + self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head) + + self._model_generation_phase = GemmaGenerationWrapper( + lm_head=self.lm_head, + model=self.model, + dtype=dtype, + ) + + if self.config.fp8: + self.fp8_recipe = get_default_fp8_recipe() + + # Rotary position embedding remains the same for all the layers and so + # created here. This makes it compatible with CUDA Graphs too. + self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)( + max_seq_len=self.config.max_position_embeddings + ).cuda() + + @staticmethod + def _padding_to_end(inputs, lengths, max_seq_len=None): + """ + Gets the tensor with sequence padded from the beginning and + updates it inplace to be padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Trim the inputs to no extra padding i.e. fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = max_seq_len + inputs.data = new_input_ids[:, :actual_max_seq_len] + + def _create_or_fetch_hidden_states_buffer(self, input_ids: torch.Tensor): + """ + Returns a tensor of shape [b, s, hd] where `b` is the batch size, + `s` is the sequence length, and `hd` is the hidden size. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + tensor = torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + return tensor + + def _create_or_fetch_inference_params(self, *args, **kwargs): + """ + Creates an InferenceParams object. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + infer_params = InferenceParams(*args, **kwargs) + return infer_params + + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + """ + Returns a tensor of shape [b, 1, hd] where `b` is the batch size, + `hd` is the hidden size. + + The buffer for generation is some part (beginning) of hidden states buffer. + This function returns pointer to it and also copies there data if provided. + """ + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that `hidden_states_buffer[:, 0, :].unsqueeze(1)` will return + # uncontiguous buffer, which we want to avoid. + output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def setup_and_run_context_phase( + self, input_ids: torch.Tensor, inference_params: InferenceParams + ): + """ + Runs the context or prefill phase of the model. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. 
+ """ + + hidden_states = self._create_or_fetch_hidden_states_buffer(input_ids) + hidden_states.copy_(self.model.embed_tokens(input_ids)) + + # Update offsets before every forward pass (including context/prefill + # phase) to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))) + + logits = self._model_context_phase( + hidden_states, + attention_mask=None, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + logits = logits[torch.arange(logits.size(0)), lengths - 1, :] + next_tokens = torch.argmax(logits, dim=1) + + # `self.hidden_states` has shape [b, s, hd]. + # Return hidden state for the last token - output has shape [b, 1, hd]. + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs, + ): + """ + Generates next tokens auto-regressively for a batch of input tokens. + """ + self.eval() + + # Both autocasts are needed: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() + # If padding is at the beginning, then shift it to the end + TEGemmaForCausalLM._padding_to_end( + input_ids, + lengths, + max_seq_len=( + self.config.cuda_graphs_static_max_context_len + if self.config.generation_cuda_graphs + else None + ), + ) + + batch_size = input_ids.shape[0] + # For benchmark generation run, this is being set explicitly. + max_input_sequence_len = self.config.max_seq_length + + # InferenceParams is a cache, where keys and values of previous + # tokens are stored. Moreover it stores the current running lengths + # of the sequences in the current batch. + # A helper function is used to create the inference params object + # because this `generate` method is common for TEGemmaForCausalLM + # and TEGemmaForCausalLMCudaGraphs. In case of CudaGraphs, this + # function is overriden to simply return the inference params object + # that is already created in TEGemmaForCausalLMCudaGraphs' + # constructor. + inference_params = self._create_or_fetch_inference_params( + max_batch_size=batch_size, + max_sequence_length=max_input_sequence_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=batch_size * max_input_sequence_len // 16, + ) + + # Set the inference params for both the context/prefill phase and + # generation phase objects. + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + # Context/prefill phase. + hidden_states, next_tokens = self.setup_and_run_context_phase( + input_ids, inference_params + ) + + # Generation phase. 
+ lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + output_tokens = [next_tokens] + + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase( + hidden_states, + mask=None, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + # Increase sequence offsets by one because we generated one token + # for every sequence. + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + + # `next_tokens` is a static output tensor, so we need to clone + # it because it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + """ + Forward pass for the model. This is used in calibration step when + forward pass is needed to generate FP8 calibration data. + """ + + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + kwargs["input_ids"] == 0 + ), # Hardcoded, this only applies to bshd/sbhd layouts. + attn_mask_type="padding_causal", + ) + return logits + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is a wrapper over the class TEGemmaForCausalLM + and uses CUDA Graphs to speed up the generation process. We need to make one + trade-off - batch_size, max_seq_len and max_context_seq_len need to + be static. It is necessary to run generation without changing the pointer + to the variables that are recorded in the graph. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + + self.config = config + + # Preparation of the static buffer to hold the hidden states that are + # passed from one layer to the next. + self.hidden_states_buffer = torch.empty( + ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + self.config.hidden_size, + ) + ).cuda() + + # This is in fact part of the buffer for hidden_states. Refer to the + # `_get_generation_buffer` function for more details. + self.generation_buffer = self._get_generation_buffer( + self.hidden_states_buffer, + ) + + # InferenceParams contains the keys and values cache. Refer to the + # original call in TEGemmaForCausalLM's `generate` method for more + # details. + self.inference_params = InferenceParams( + max_batch_size=self.config.cuda_graphs_static_batch_size, + max_sequence_length=self.config.cuda_graphs_static_max_context_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=self.config.cuda_graphs_static_batch_size + * self.config.cuda_graphs_static_max_context_len + // 16, + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + """ + Here "the trick" happens. `_model_context_phase` and + `_model_generation_phase` from TEGemmaForCausalLM are replaced with + their recorded version. Once the graphs are recorded, they can be + replayed with minimal usage of CPU and that leads to speedup. 
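
        Two graphs are captured with fixed shapes: the context/prefill phase
        on the full [b, max_ctx, h] hidden states buffer and the generation
        phase on its [b, 1, h] view, so the same static buffers must be
        reused when the graphs are replayed.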
+ """ + # Record the model with training=False, because it will be used in + # generation. + self.eval() + + # Setup the recording for context/prefill phase. + input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + + # Hardcoded value for the context length. + lengths = torch.tensor([9] * self.config.cuda_graphs_static_batch_size).to( + device="cuda", dtype=torch.int32 + ) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for context/prefill phase. + self._model_context_phase = self.record_graph( + self._model_context_phase, + self.hidden_states_buffer, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + # Setup the recording for generation phase. + input_shape = (self.config.cuda_graphs_static_batch_size, 1) + lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for generation phase. + self._model_generation_phase = self.record_graph( + self._model_generation_phase, + self.generation_buffer, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + def _create_or_fetch_hidden_states_buffer(self, *args, **kwargs): + """ + Overriden to make `hidden_states` static i.e. not change its pointer + in memory between every invocation. + + Returns the static buffer for `hidden states` which is already created + in the constructor. This is the same buffer as used in the + context/prefill phase. + """ + return self.hidden_states_buffer + + def _create_or_fetch_inference_params(self, *args, **kwargs): + """ + Overriden to make `inference_params` static i.e. not change its pointer + in memory between every invocation. + + Returns the static buffer for `inference_params` which is already created + in the constructor. + """ + self.inference_params.reset() + return self.inference_params + + @torch.no_grad() + def record_graph(self, function, input_tensor, **sample_kwargs): + """ + Records the graph for the given function. The function is invoked on + argument (self.hidden_states,) and all kernels are recorded. + It then returns the captured callable, which can be run later while + minimizing CPU usage. + """ + fp8_recipe = get_default_fp8_recipe() + + # We need both autocasts: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=5, + sample_kwargs=sample_kwargs, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100755 index 0000000000..d0df9edc58 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + """ + Loads weights and FP8 metadata from a calibrated weights file. + + The weights are in BF16 precision, but the state dict also contains + fp8 metadata computed by the calibration procedure. + """ + + fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename) + + # A hack to remove the extra state from the fp8_metadata_sd + # that contains the extra state from the core_attention module. + fp8_metadata_sd = { + k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k + } + vanilla_model.load_state_dict( + fp8_metadata_sd, + strict=False, + # Because some parameters have multiple pointers to the same weight + # vanilla_model._model_context_phase.model and + # vanilla_model._model_generation_phase.model we need to load the + # weights in a non-strict manner. + ) + + +def _load_weights_for_standard_model(vanilla_model, config): + """ + Loads weights from the HuggingFace checkpoint. + """ + + archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy remaining parameters like embedding. + vanilla_model.load_state_dict(total_dict, strict=False) + + # Force mem release. Taken from huggingface code. + del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Loads the TE model with proper weights. + """ + + # Force the dtype to bfloat16 while loading the model. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel + + # Loading model with FP8 only weights needs both the following context managers. + # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights. + # 2. torch.no_grad() during TE modules' initilization so that they respect + # the `fp8_model_init` context manager. + with torch.no_grad(), fp8_model_init(config.fp8_model_init): + # Just create a model with random weights. + vanilla_model = cls(config).cuda() + + # Copy proper weights into the model. If loading weights with FP8 metadata, + # then the source weights are basically the same as the weights in the model. + # If not, then we need to load the weights from the HuggingFace checkpoint + # and do mapping of the weight names from HF to the TE model. 
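+ # (`config.fp8_model_weights_filename`, when set, is expected to point at a
+ # state dict produced by the FP8 calibration step in the tutorial notebook,
+ # e.g. "calibrated_weights.pth".)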
+ if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + # Restore the original dtype. + torch.set_default_dtype(old_dtype) + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. + """ + all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_( + hf_state_dict[layer_prefix + hf_name] + ) + + copy_from_ht_to_te( + "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" + ) + copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") + copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") + copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size + ) + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size + ) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. 
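+ For example, with head dimension d, rows [0, d) of the fused weight hold
+ q1, rows [d, 2d) hold k1, rows [2d, 3d) hold v1, and q2 starts at row 3d
+ (see `copy_interleave` below).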
+ """ + te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst_slice = slice( + dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim + ) + src_slice = slice( + head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim + ) + te_qkv_layer[dst_slice, :] = src[src_slice, :] + + copy_interleave("self_attn.q_proj.weight", 0) + copy_interleave("self_attn.k_proj.weight", 1) + copy_interleave("self_attn.v_proj.weight", 2) + else: + copy_from_ht_to_te( + "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" + ) + copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") + copy_from_ht_to_te( + "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" + ) + + return all_layer_prefixes diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100755 index 0000000000..cc8675cfd8 --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,941 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "87e8360b-8d08-44bc-9333-79ba949afe8c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Accelerating Hugging Face Gemma Inference with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "id": "2da33092-eef5-46a4-b222-0188cc6e5079", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Introduction\n", + "\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "

\n", + "\"\"\n", + "
\n", + "Animation 1: Hugging Face Gemma model token generation.\n", + "
\n", + "
\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n", + "\n", + "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n", + "\n", + "### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n", + "\n", + "The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n", + "\n", + "\n", + "Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n", + "\n", + "1. Internal fragmentation\n", + "2. External Fragmentation\n", + "\n", + "More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n", + "\n", + "\n", + "Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n", + "\n", + "\n", + "### 2. CUDA Graphs API\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n", + "\n", + "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. 
More information about CUDA Graphs in PyTorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "
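\n",
+ "As a rough illustration (this snippet is not part of the tutorial files; the module and shapes here are made-up assumptions), graphing a single callable with the PyTorch API alone might look like this:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "model = torch.nn.Linear(1024, 1024).cuda()\n",
+ "sample_input = torch.randn(8, 1024, device=\"cuda\")\n",
+ "\n",
+ "# Capture the module's CUDA work into a graph; subsequent calls replay the\n",
+ "# recorded kernels with minimal CPU launch overhead.\n",
+ "graphed_model = torch.cuda.make_graphed_callables(model, (sample_input,))\n",
+ "\n",
+ "out = graphed_model(sample_input)\n",
+ "```\n",
+ "\n",
+ "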
\n", + "\"\"\n", + "
\n", + "Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n", + "
\n", + "
\n", + "\n", + "### 3. FP8 Scaling Factors Calibration\n", + "\n", + "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n", + "\n", + "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 2:\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", + "
\n", + "
\n", + "\n", + "### 4. FP8 Model Weights\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n", + "\n", + "The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 3: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using fp8_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", + "
\n", + "
\n", + "\n", + "### Benchmarking\n", + "\n", + "We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n", + "\n", + "
\n", + "Note\n", + " \n", + "This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", + "metadata": {}, + "source": [ + "The following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - This file contains the necessary Python packages for this tutorial.\n", + "5. `media/`\n", + " - This directory contains the images and other artefacts used in this tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "36767694-a1c5-4a00-a075-7addc55d8307", + "metadata": {}, + "source": [ + "### Setup and checks" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and run this cell when running the tutorial for the first time\n", + "# %pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c756ebbd-24c9-4a54-a381-e7c02c555206", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "HuggingFace Transformers library offers generation API. \n", + "HuggingFace generation for the Gemma model will be used as a baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2803e0ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. 
The\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n", + "* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 46.60 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "\n", + "model = init_baseline_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b3698dc6", + "metadata": {}, + "source": [ + "Let's put this time into the table for later comparison.\n", + "\n", + "| Models | Time | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 46.6 s | - |" + ] + }, + { + "cell_type": "markdown", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Optimization 1] Accelerating generation with Transformer Engine " + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9dceef93", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 12.25 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": {}, + "source": [ + "With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!" + ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |" + ] + }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Optimization 2] More acceleration with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n", + "```python\n", + " def __init__(self, config : GemmaConfig):\n", + " \"\"\"\n", + " Here \"the trick\" happens. `_model_context_phase` and\n", + " `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n", + " their recorded version. Once the graphs are recorded, they can be\n", + " replayed with minimal usage of CPU and that leads to speedup.\n", + " \"\"\"\n", + " (...)\n", + " # Record the graph for context/prefill phase.\n", + " self._model_context_phase = \n", + " self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n", + "\n", + " (...) 
\n", + " # Record the graph for generation phase.\n", + " self._model_generation_phase = \n", + " self.record_graph(self._model_generation_phase, self.generation_buffer)\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " \"\"\"\n", + " Records the graph for the given function. The function is invoked on\n", + " argument (self.hidden_states,) and all kernels are recorded.\n", + " It then returns the captured callable, which can be run later while\n", + " minimizing CPU usage.\n", + " \"\"\"\n", + " fp8_recipe = get_default_fp8_recipe()\n", + "\n", + " # We need both autocasts: FP8 for operations that can run in lower\n", + " # precision and BF16 for those that cannot.\n", + " with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function,\n", + " (input_tensor,),\n", + " fp8_enabled=self.config.fp8,\n", + " fp8_recipe=fp8_recipe,\n", + " allow_unused_input=True,\n", + " num_warmup_iters=5,\n", + " sample_kwargs=sample_kwargs,\n", + " )\n", + " return graphed_function\n", + "```\n", + "\n", + "It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n", + "\n", + "*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 6.39 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.max_seq_length = 512\n", + "run_config.batch_size = 64\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n", + "\n", + "1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n", + "2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n", + "3. 
So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n", + "4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n", + "\n", + "(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n", + "\n", + "
\n", + "\n", + "
\n", + " \n", + "Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Optimization 3] Even more acceleration with FP8 precision " + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "### Calibrating FP8 scaling factors for correctness\n", + "\n", + "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", + "\n", + "1. Model weight tensors\n", + "2. Input tensors\n", + "\n", + "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n", + "\n", + "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", + "\n", + "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n", + " \n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n", + "
\n", + "
\n", + "\n", + "\n", + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aecee0e1", + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "import transformer_engine.pytorch as te\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "run_config.fuse_qkv_params = True\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "# Calibration\n", + "with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " model.train()\n", + " run_forward_pass(model, run_config, num_iters=64)\n", + "\n", + "# Compute scale_fwd with enabled fp8 autocast\n", + "with te.fp8_autocast(enabled=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " run_forward_pass(model, run_config, 1)\n", + "\n", + "# Some parameters are in pointing to the same tensors, double save is avoided here.\n", + "dict_to_save = {\n", + " k: v\n", + " for k, v in model.state_dict().items()\n", + " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n", + "}\n", + "torch.save(\n", + " dict_to_save, \"calibrated_weights.pth\"\n", + ") # <-- Add path to save calibrated weights." + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "### Generation with better FP8 scaling factors\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n", + "
\n", + "
\n", + "\n", + "Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a913f54d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 8.73 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8\n", + "run_config.fp8 = True\n", + "# Calibrated fp8 weights are loaded directly from the file.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "8cdbb56c", + "metadata": {}, + "source": [ + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", + "\n", + "### Use of FP8-only model weights\n", + "\n", + "Running the model in FP8 precision does not imply that the weights are stored in FP8. 
By default, they are stored in higher precision and are cast to FP8 using the saved scaling factors right before GEMM operations (matrix multiplications).\n", + "\n", + "This approach is appropriate during training: gradients produced during the backward pass arrive in higher precision, so keeping higher precision copies of the model weights helps, as they have enough dynamic range to absorb the incoming information from the gradients. During the forward pass, the higher precision model weights and the batch inputs are cast to FP8, and the GEMMs run in FP8 precision. This saves training time overall as long as the time saved by running the GEMMs in FP8 (rather than in higher precision) exceeds the extra time spent on the cast operations.\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 7: Running the model at higher precision involves only one operation - GEMM. However, when the model operates in FP8, it requires casting inputs to the GEMM - namely, model weights and batch inputs from higher precision to FP8, which involves extra kernels in addition to the low-precision GEMM kernel.\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d", + "metadata": {}, + "source": [ + "However, things change during inference. Since the weights need no update and remain frozen, higher precision copies of weights could be avoided completely. It is possible to cast the higher precision weights only once to FP8 precision while initializing the model with appropriate scaling factors and then use those FP8-only copies of weights during the entirety of token generation. This provides two-fold benefits:\n", + "\n", + "1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n", + "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n", + "\n", + "\n", + "Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4562ee82-8c95-4736-8815-cd386078a485", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory required for 16384x16384 linear layer: \n", + "FP32 - 1024.0 MB, \n", + "BF16 - 512.0 MB, \n", + "FP8 - 256.0 MB, \n", + "\n", + "Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n", + "Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n", + "Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n" + ] + } + ], + "source": [ + "import torch\n", + "import transformer_engine.pytorch as te\n", + "\n", + "H = 2**14\n", + "D = 2**14\n", + "print(f\"Memory required for {H}x{D} linear layer: \\n\"\n", + " f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n", + " f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n", + " f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n", + "\n", + "linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n", + "print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp32\n", + "\n", + "linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n", + "print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_bf16\n", + "\n", + "# Initialize model weights in FP8 precision\n", + "with torch.no_grad(), te.fp8_model_init(enabled=True):\n", + " linear_fp8 = te.Linear(H, D)\n", + "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp8" + ] + }, + { + "cell_type": "markdown", + "id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "
\n", + " Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", + "
\n", + "
\n", + "\n", + "Let's run the code with `fp8_model_init`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96264b9c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 4.99 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8 math and FP8 model weights\n", + "run_config.fp8 = True\n", + "run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "The final speedup is **9.3x**. 
\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "id": "7bb2452d", + "metadata": {}, + "source": [ + "This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n", + "\n", + "1. Support for KV Caching (both non-paged and paged),\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 scaling factors calibration,\n", + "4. Keeping model parameters in FP8 precision.\n", + "\n", + "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n", + "\n", + "1. Longer context lengths (with paged KV cache) \n", + "2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n", + "\n", + "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100755 index 0000000000..cc31afc65a --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,370 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import sys +import IPython +import random +import string + +from te_gemma_loading_weights import load_te_model +import torch +from torch.utils.data import DataLoader + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + AutoConfig, +) +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset + + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + +random.seed(42) +torch.manual_seed(42) + + +class RunConfiguration: + def __init__(self): + self.mixed_precision = "bf16" + self.model_name = None + + # FP8 precision settings + self.fp8 = False + self.fp8_model_weights_filename = None + self.fp8_model_init = False + + # Cuda graphs + self.generation_cuda_graphs = False + self.cuda_graphs_static_batch_size = 64 + self.cuda_graphs_static_max_seq_len = 512 + self.cuda_graphs_static_max_context_len = 512 + + # Finetuning/calibration/generation settings + self.dataset_name = "timdettmers/openassistant-guanaco" + self.dataset_text_field = "text" + self.learning_rate = 1.41e-5 + self.batch_size = 64 + self.max_seq_length = 512 + self.gradient_accumulation_steps = 1 + self.num_warmup_steps = 5 + self.num_training_steps = 10 + + # Coalesced QKV params or not + self.fuse_qkv_params = False + + # Attention + self.is_paged = False + + # This is either provided by the user or it will be set when the + # model weights are downloaded. + self.weights_cache_dir = "" + + +# Global variable for the run configuration so that it can be easily accessed +# throughout the jupyter notebook with an `import * from utils` statement +run_config = RunConfiguration() + + +def get_dataloaders(run_config): + """ + Returns a basic dataloader for the dataset which contains tokenized batches + of text. + """ + dataset = load_dataset(run_config.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=run_config.max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + # Tokenize the dataset + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": run_config.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + + +def ensure_model_is_downloaded(run_config): + """ + Downloads and caches the model weights if not already downloaded. A valid + Huggingface Access Token is required to download the model weights. + """ + assert run_config.model_name in [ + "google/gemma-7b", + ], "Only Gemma 7B model is supported!" + + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(run_config.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." 
+ ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + run_config.weights_cache_dir if run_config.weights_cache_dir != "" else None + ) + run_config.weights_cache_dir = snapshot_download( + repo_id=run_config.model_name, cache_dir=supplied_cache_dir + ) + + +def init_baseline_model(run_config): + """ + Initializes a baseline HF Gemma model with the model name provided in + the run_config. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + # Init the model + config = AutoConfig.from_pretrained(run_config.model_name) + + # Make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + run_config.model_name, + config=config, + torch_dtype=torch.bfloat16, + ).cuda() + + return model + + +def init_te_gemma_model(run_config): + """ + Initializes a Gemma model with `GemmaDecoderLayer`s swapped with + `TransformerLayer`s from TransformerEngine. In case CUDA Graphs are enabled, + the model is initialized from `TEGemmaForCausalLMCudaGraphs` class. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + cls = TEGemmaForCausalLMCudaGraphs if run_config.generation_cuda_graphs else TEGemmaForCausalLM + config = AutoConfig.from_pretrained(run_config.model_name) + + # Inject all fields from the `run_config` to the model `config` to make the + # code simpler. + for key, value in run_config.__dict__.items(): + setattr(config, key, value) + + # Initialize the model and move it to the GPU. + model = load_te_model(cls, config).cuda() + + # Record the model if CUDA Graphs are enabled. + if run_config.generation_cuda_graphs: + model.record() + + return model + + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + + HTML("") + + if torch.cuda.memory_allocated() != 0: + print( + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) + + # Suppress the warnings + if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + torch.set_warn_always(False) + + +@torch.no_grad() +def run_forward_pass(model, run_config, num_iters): + """ + Runs the forward pass of the model with sample data. Intended to use for + warmup and/or calibration. + """ + train_dataloader = get_dataloaders(run_config) + + model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + batch["attention_mask"] = batch["attention_mask"].cuda() + model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) + + +############################################################################### +# Benchmarking and example generation functions. +############################################################################### + + +def print_sample_of_generated_texts(model, run_config): + """ + Prints a sample of generated texts from the input model. 
+ """ + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + prompts = [ + "Here are the two facts about GPUs:", + "Some facts about NVIDIA:", + "The fundamental theorem of calculus for the layman:", + "A fact about AI:", + ] + + # Repeat prompts to match batch size + prompts *= run_config.batch_size // len(prompts) + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + + max_length = inputs["input_ids"].size(1) + new_length = ((max_length + 63) // 64) * max_total_tokens + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (new_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + outputs = model.generate(**inputs, max_new_tokens=50) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def print_output(prompts, generated_texts, idx): + print("=" * 30 + f" Generation example {idx+1} " + "=" * 30) + print(f'Prompt: "{generated_texts[idx][: len(prompts[idx])]}"') + print(f'Generated text: "{generated_texts[idx][len(prompts[idx]) :]}"') + + # Print the output from first two prompts + for i in range(2): + print_output(prompts, generated_texts, i) + + +def _generate_random_words(num_words, max_word_length): + """ + Generates random words for the benchmark. + """ + + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words + + +def benchmark_generation(model, run_config, context_length=20): + """ + Benchmarks the generation time for a random input to the model. 
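+
+    Each batch element is given a random lowercase word (up to `context_length`
+    characters) as its prompt, the inputs are left-padded to the maximum total
+    sequence length, and the `generate` call is timed with CUDA events.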
+ """ + + batch_size = run_config.batch_size + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + max_new_tokens = max_total_tokens - context_length + + print("\n" + "=" * 80) + print( + f"Benchmarking for batch_size = {batch_size}, prefill tokens =" + f" {context_length} and max new tokens = {max_new_tokens}" + ) + + input_str = _generate_random_words(batch_size, context_length) + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) + + max_context_tokens = inputs["input_ids"].size(1) + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (max_total_tokens - max_context_tokens, 0), + value=tokenizer.pad_token_id, + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (max_total_tokens - max_context_tokens, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens) + torch.cuda.synchronize() + end.record() + + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 7013e85ec6..00499cff5f 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -5,7 +5,7 @@ "id": "6a5b2993", "metadata": {}, "source": [ - "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", + "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n", "\n", "
\n", "\n", diff --git a/docs/index.rst b/docs/index.rst index e678b1d467..2c04810f4d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,7 @@ Transformer Engine documentation examples/fp8_primer.ipynb examples/advanced_optimizations.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb + examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb .. toctree:: diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 11f07d9133..9e39b84c0b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -465,14 +465,23 @@ def _test_norm_forward( x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_mu, ref_rsigma = _jax_layernorm( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) else: output, rsigma = tex.rmsnorm_fwd( x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_rsigma = _jax_rmsnorm( - x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) ref_mu = None diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0e8501abf3..7078cb69de 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -36,6 +36,12 @@ 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA + "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA + "cp_3_2": ModelConfig( + 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 + ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128), # MLA } @@ -81,6 +87,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) + if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently only support KV P2P!") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 0b0732dfa8..0e01f0b04a 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,8 +2,11 @@ # # See LICENSE for license information. 
+import contextlib +import gc import os -from contextlib import nullcontext +from typing import Iterable, Optional + import pytest import torch @@ -11,15 +14,16 @@ from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported from utils import ModelConfig, get_available_attention_backends -# Check if FP8 is supported +# Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() -fp8_recipes = [None] +quantization_recipes: Optional[recipe.Recipe] = [None] if fp8_available: - fp8_recipes.append(recipe.Float8CurrentScaling()) - fp8_recipes.append(recipe.DelayedScaling()) + quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) model_config = { "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), @@ -48,85 +52,139 @@ "transformer_layer": lambda: te.TransformerLayer( SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 ), + "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + "layernorm_mlp_ops": lambda: te.ops.Sequential( + te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + ), } -def _get_input(): - return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() +def _make_input() -> torch.Tensor: + """Generate random input tensor.""" + return torch.randn( + (128, SIZE, SIZE), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) -def _get_fp8_weight_cache_size(models, fp8_recipe): - """ - Calculate the total FP8 weight cache size (in MB) for a list of models. 
- """ - if fp8_recipe is None: +def _warmup_model( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> None: + """Perform forward and backward pass""" + tensor = _make_input() + for module in modules: + with te.fp8_autocast( + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + ): + tensor = module(tensor) + tensor.sum().backward() + + +def _estimate_cached_weight_size( + model_name: str, + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> float: + """Calculate the memory (in MiB) needed for weight caching.""" + + # The weight params are cached directly for unquantized compute + if quantization_recipe is None: return 0 - params_bytes = 0 - for model in models: - for name, param in model.named_parameters(): - if "weight" in name: - params_bytes += param.numel() + # Count number of weight param elements + param_elements = 0 + for module in modules: + for param in module.parameters(): + if param.dim() == 2: + param_elements += param.numel() + + # FP8 tensor-scaling caches one byte per element + if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): + if not is_non_tn_fp8_gemm_supported() and model_name not in ( + "linear_op", + "layernorm_mlp_ops", + ): + # Modules do not deallocate FP8 transpose for weights + return 2 * param_elements / 1024**2 + return param_elements / 1024**2 + + # MXFP8 caches one data byte per element and one scale byte per 32 + # elements + if quantization_recipe.mxfp8(): + if model_name not in ("linear_op", "layernorm_mlp_ops"): + # Modules do not deallocate column-wise MXFP8 data for weights + return 2 * param_elements * (1 + 1 / 32) / 1024**2 + return param_elements * (1 + 1 / 32) / 1024**2 + + raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + + +def _measure_cached_memory( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], + cpu_offload: bool, +) -> float: + """Measure the growth in allocated GPU memory in MiB after a model forward pass. + + Memory measurement excludes the input and output tensors. 
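+
+    With offloading enabled, the remaining growth is expected to come mainly
+    from weight data cached for the backward pass (compare against
+    `_estimate_cached_weight_size`).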
- # One byte for columnwise and one byte for rowwise, - # hence multiply by 2 and convert to MB - # there is 1 byte of scale per 32 elements in mxFP8 - factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 - return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) + """ + # Reset memory + gc.collect() + torch.cuda.empty_cache() -def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): - tensor = _get_input() + # Context and sync function for CPU offloading if cpu_offload: offload_context, sync_function = te.get_cpu_offload_context( enabled=True, - num_layers=len(models) - 1, - model_layers=len(models), + num_layers=len(modules), + model_layers=len(modules) + 1, offload_activations=True, offload_weights=False, ) else: - offload_context = nullcontext() + offload_context = contextlib.nullcontext() sync_function = lambda x: x - for model in models: + # Forward pass, with dummy step to trigger offload for last module + inp = _make_input() + tensor = inp + memory_before_forward = torch.cuda.memory_allocated() / (1024**2) + for module in modules: with te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe ), offload_context: - tensor = model(tensor) + tensor = module(tensor) tensor = sync_function(tensor) + with offload_context: + tensor = tensor.clone() + tensor = sync_function(tensor) + memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) - max_mem_used = torch.cuda.memory_allocated() / (1024**2) - torch.cuda.synchronize() - + # Backward pass tensor.sum().backward() + torch.cuda.synchronize() - return max_mem_used + # Memory usage in MiB + return memory_after_forward - memory_before_forward -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("model_key", model_types.keys()) -def test_cpu_offload(fp8_recipe, model_key) -> None: - """ - We run three configurations: - (1) No offloading: All activations remain on the GPU between forward and backward passes. - (2) No offloading (one layer): Only the first layer's activations remain on the GPU between - forward and backward passes. - (3) With offloading (all layers): Only the last layer's activations remain on the GPU - between forward and backward passes, while all other layers are offloaded to the CPU. - - We expect the memory consumption of configurations (2) and (3) to be similar, with - the difference being the size of the FP8 cache that is not offloaded to the CPU. - We also expect this memory consumption to be smaller than in scenario (1). 
- """ - import gc +@pytest.mark.parametrize("quantization_recipe", quantization_recipes) +@pytest.mark.parametrize("model_name", model_types.keys()) +def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: + """Check that CPU offloading runs and has expected memory usage.""" - gc.collect() - - model_cls = model_types[model_key] - models_list = [model_cls() for _ in range(NUM_LAYERS)] - - if model_key in ["multihead_attention", "transformer_layer"]: + # Construct model + modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] + if model_name in ["multihead_attention", "transformer_layer"]: available_backends, *_ = get_available_attention_backends( model_config["small"], qkv_dtype=torch.bfloat16, @@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: os.environ["NVTE_FLASH_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - without_offloading = _measure_memory_between_forward_and_backward( - models_list, fp8_recipe, False - ) - without_offloading_one_layer = _measure_memory_between_forward_and_backward( - models_list[:1], fp8_recipe, False - ) - with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) + # Warmup + _warmup_model(modules_list, quantization_recipe) - assert with_offloading < without_offloading + # Measure cached memory after forward pass + memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) + memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) - # The only difference between the memory consumption of with_offloading - # and without_offloading_one_layer should be the size of the FP8 weights cache, - # which is not offloaded to the CPU. - memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) - assert ( - memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON + # Check for expected memory usage + assert memory_with_offload < memory_without_offload + memory_from_cached_weights = _estimate_cached_weight_size( + model_name, + modules_list, + quantization_recipe, ) + assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index eb4edc5cbd..3812eb28f8 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1746,11 +1746,11 @@ def test_clamped_swiglu( ) # A low value of limit = 0.1 is used for this test instead of the original # default = 7.0 used in GPT OSS. This is because low value kills decent number - # of gradients allowing us to check for correctness of gradient computation of + # of gradients allowing us to check for correctness of gradient computation of # ClampedSwiGLU. 
limit = 0.1 alpha = 1.702 - + # Plain PyTorch implementation x_glu, x_linear = x_ref.chunk(2, dim=-1) x_glu = x_glu.clamp(min=None, max=limit) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e720673675..a0e285b913 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -39,16 +39,21 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states, get_available_attention_backends + # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -120,6 +125,11 @@ fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +use_cutlass_grouped_gemm = [False] +# Only enable cutlass grouped gemm on Hopper +if torch.cuda.get_device_capability() == (9, 0): + use_cutlass_grouped_gemm.append(True) + def is_fused_attn_available( config: ModelConfig, @@ -1800,6 +1810,7 @@ def test_grouped_linear_accuracy( bias, delay_wgrad_compute, parallel_mode=None, + use_cutlass=False, ): fp8 = recipe is not None if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: @@ -1871,9 +1882,47 @@ def test_grouped_linear_accuracy( delay_wgrad_compute, ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + for o, o_ref in zip(outputs, outputs_ref): + if use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy_cutlass( + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, +): + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + use_cutlass=True, + ) + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) @pytest.mark.parametrize("dtype", 
param_types, ids=str) @@ -2537,10 +2586,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): (16, 10027, 128, 512), ], ) -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) -def test_grouped_gemm(shape, dtype, layout, accumulate): +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape @@ -2575,6 +2625,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): grad = True single_output = False + if use_cutlass: + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + for i in range(z): general_gemm( A[i], @@ -2602,9 +2655,82 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): single_output=single_output, ) - # should be bit-wise match for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + if use_cutlass: + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + + +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "input_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + ], +) +@pytest.mark.parametrize( + "out_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + Float8Quantizer( + torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3 + ), + ], +) +def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer): + # For MXFP8 and CurrentScaling, below unfused quantization should happen + # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output + # Skip invalid configurations + is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance( + out_quantizer, MXFP8Quantizer + ) + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if is_mxfp8_needed and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + outp_type = torch.float32 + quantized_out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=out_quantizer, + bias=None, + use_split_accumulator=False, + ) + + out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=None, + bias=None, + use_split_accumulator=False, + ) + expected_quantized_out = out_quantizer(out) + + # Match results again Pytorch GEMM and allow for quantization tolerance + pytorch_out = torch.matmul( + inp_fp8.dequantize().to(torch.float64), + torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1), + ) + fp8_tols = dict(rtol=0.125, atol=0.0675) + torch.testing.assert_close( + pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols + ) + # Match results between quantization happening inside vs outside general_gemm + torch.testing.assert_close(expected_quantized_out.dequantize(), 
quantized_out.dequantize()) @pytest.mark.parametrize( diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cb9f13b899..08e876404c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -45,6 +45,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") endif() include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +set(CUTLASS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") +set(CUTLASS_TOOLS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") + # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -81,6 +86,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cutlass_grouped_gemm.cu normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -121,18 +127,30 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") - +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "gemm/cutlass_grouped_gemm.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") +else() + message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") +endif() # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) + target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index d1d21be63d..0264bc9fbb 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -36,7 +36,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp } void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, - cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_clamped_swiglu); using namespace transformer_engine; ClampedSwiGLUParam param = {limit, alpha}; @@ -44,10 +44,10 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, } void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, float alpha, cudaStream_t stream) { + float limit, float alpha, cudaStream_t stream) { NVTE_API_CALL(nvte_clamped_dswiglu); using namespace transformer_engine; ClampedSwiGLUParam param = {limit, alpha}; - dgated_act_fn, oss_dsilu>(grad, input, output, - param, stream); + dgated_act_fn, oss_dsilu>( + grad, input, output, param, stream); } diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d90dd3abc1..ec29e6e120 100644 --- 
a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -607,17 +607,21 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr int comm_bytes_per_rank = comm_bytes / _tp_size; // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush - userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, - send_stream); - userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, - recv_stream); + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, recv_stream); + // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf for (auto stream : {send_stream, recv_stream}) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); - // We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); } + + // Next we sync with the main stream + // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } /*************************************************************************************************** diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 17f3cf658e..1dcd54d0d7 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2542,25 +2542,27 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream) { + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; for (int j = 1; j < tp_size; j++) { int i = (tp_rank + j) % tp_size; int send_offset = srcoffset + bytes_per_slice * tp_rank; int recv_offset = dstoffset + bytes_per_slice * tp_rank; - userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, - stream); + userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream) { + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; for (int j = tp_size - 1; j > 0; j--) { int i = (tp_rank + j) % tp_size; int send_offset = srcoffset + bytes_per_slice * i; int recv_offset = dstoffset + bytes_per_slice * i; - 
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, - stream); + userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 8077f90be8..4d52fbb644 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -306,10 +306,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream); + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream); + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 60b10862e6..795697635d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -251,11 +251,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && - // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 || - cudnn_runtime_version == 91300) && - is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && - !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && + // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed + (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && + head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && + head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9e6c5417bc..f287072bcb 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -19,6 +19,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#include "cutlass_grouped_gemm.cuh" namespace { @@ -650,9 +651,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor CUBLAS_VERSION); #endif NVTE_CHECK( - cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + transformer_engine::cuda::cudart_version() >= 12020 && + transformer_engine::cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", - cuda::cudart_version()); + 
transformer_engine::cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", @@ -675,13 +677,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor n_split, gemm_producer, inputCounter, stream); } -void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, - const NVTETensor *bias, NVTETensor *pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor *workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_stream_cublas_gemm); +void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { using namespace transformer_engine; int num_streams = nvte_get_num_compute_streams(); @@ -711,6 +711,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT } } +void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor *workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_stream_cublas_gemm); + using namespace transformer_engine; + + // Deprecation warning + NVTE_WARN( + "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. " + "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when " + "applicable)."); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, + accumulate, use_split_accumulator, math_sm_count, stream); +} + namespace transformer_engine { using cublasHandleManager = detail::HandleManager; @@ -718,3 +737,85 @@ using cublasHandleManager = detail::HandleManager("NVTE_USE_CUTLASS_GROUPED_GEMM", false); + const bool warn_fallback = + transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); + + auto cublas_path = [&]() { + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + workspace, accumulate, use_split_accumulator, math_sm_count, stream); + }; + + // Currently only support cutlass group gemm on Hopper Arch + if (!(is_hopper && use_cutlass)) { + cublas_path(); + return; + } + + auto is_empty_arr = [&](const NVTETensor *p) -> bool { + if (p == nullptr) return true; + for (int i = 0; i < num_gemms; ++i) { + if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; + } + return true; + }; + + auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { + int64_t ref_k = -1; + for (size_t i = 0; i < num_gemms; i++) { + const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]); + const int k = trans ? 
tensor->data.shape[0] : tensor->data.shape[1]; + + if ((k & 127) != 0) return false; + + if (ref_k < 0) + ref_k = k; + else if (k != ref_k) + return false; + } + + return true; + }; + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_type = get_cuda_dtype(inputA->data.dtype); + auto B_type = get_cuda_dtype(inputB->data.dtype); + auto D_type = get_cuda_dtype(OutputD->data.dtype); + + return (A_type == B_type) && (A_type == D_type) && + ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); + }; + + // CUTLASS Grouped GEMM fast path (SM90/TMA) + // Conditions: + // - No fused epilogue: both bias and pre_gelu_out are empty. + // - Supported dtypes only: FP16/BF16 (FP32 accumulate). + // - Uniform K across groups and K % 128 == 0. + // - use_split_accumulator is ignored for FP16/BF16. + // - grad is irrelevant when bias/pre_gelu_out are empty. + // + // Otherwise, fall back to cuBLAS. + if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && + all_groups_uniform_k128(B, transb)) { + cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, + current_device, math_sm_count, stream); + } else { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu new file mode 100644 index 0000000000..18736c4f54 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ **************************************************************************************************/ + +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass_grouped_gemm.cuh" + +namespace transformer_engine { +namespace grouped_gemm { + +// Explicit template instantiation to match the template declarations in the .cuh +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); + +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream) { + using namespace transformer_engine; + auto* inputA = convertNVTETensorCheck(A[0]); + auto* inputB = convertNVTETensorCheck(B[0]); + + float one = 1.0; + float zero = 0.0; + float alpha = one; + float beta = (accumulate) ? one : zero; + + auto dispatch = [&](auto tag) { + using T = decltype(tag); + if (!transa && !transb) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (!transb && transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (transb && !transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else { + NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm."); + } + }; + + if (inputA->data.dtype == DType::kBFloat16) { + dispatch(cutlass::bfloat16_t{}); + } else if (inputA->data.dtype == DType::kFloat16) { + dispatch(cutlass::half_t{}); + } else { + NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported."); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh new file mode 100644 index 0000000000..1add571325 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// +// Copyright (c) 2025 Shopee Inc. All Rights Reserved. +// + +/** + * @file: cutlass_grouped_gemm.cuh + * @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com + * @date: 2025-08-08 16:20:00 + * @brief: cutlass group gemm kernel. 
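+ * @note: Used by nvte_multi_tensor_gemm (enabled via NVTE_USE_CUTLASS_GROUPED_GEMM)
+ *        as an SM90-only fast path for FP16/BF16 grouped GEMM with uniform K
+ *        divisible by 128 and no fused epilogue; all other cases fall back to
+ *        the cuBLAS path.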
+ **/ + +#pragma once + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "common/util/system.h" +#include "cute/tensor.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +template +using GroupedGemmInputALayout = + std::conditional_t; + +template +using GroupedGemmInputBLayout = + std::conditional_t; + +using ProblemShapeType = cute::Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape; // per group +template +struct GemmGivenSchedule { + using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand + using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand + using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands + + // A matrix configuration + using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC = typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = + typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, 
LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +struct ScheduleConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + // TODO(Alan): Add tuning for different scenarios to select the optimal configuration, + // as the current configuration may not be the best. + + // using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + // using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + // using TileShape = Shape; + // using ClusterShape = Shape; + + using LayoutA = GroupedGemmInputALayout; + using LayoutB = GroupedGemmInputBLayout; + using LayoutC = cutlass::layout::RowMajor; + using DataType = DataType_; +}; + +template +using GemmGrouped = typename GemmGivenSchedule>::Gemm; + +template +typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host, + void* problem_sizes, const ElementA** ptr_A, + StrideA* stride_A, const ElementB** ptr_B, + StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C, + float alpha, float beta, int device, int math_sm_count) { + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info( + device, math_sm_count); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + arguments = + typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, reinterpret_cast(problem_sizes), + reinterpret_cast(problem_sizes_host)}, + {ptr_A, stride_A, ptr_B, stride_B}, + { + fusion_args, + (beta > 0.0) ? 
(const ElementC**)ptr_C : nullptr, // NOLINT(*) + stride_C, + ptr_C, + stride_C, + }, + kernel_hw_info}; + + return arguments; +} + +template +inline __device__ __host__ T ROUND_UP(T m, T n) { + return (m + n - 1) / n * n; +} + +template +void debug_type() { + std::cout << typeid(T).name() << std::endl; +} + +int64_t inline getGemmCoordSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL)); +} + +int64_t inline getPtrSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(half*), 128UL)); +} + +int64_t inline getLddSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL)); +} + +// cpu workspace size is 4MB +static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024; + +static char* getHostWorkspace() { + static std::once_flag flag; + static std::shared_ptr workspace; + + std::call_once(flag, [&]() { + workspace = + std::shared_ptr(reinterpret_cast(std::malloc(kCPUWorkSpaceSize)), [](char* p) { + if (p) std::free(p); + }); + + if (!workspace) { + throw std::bad_alloc(); + } + }); + + return workspace.get(); +} + +template +void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + NVTETensor* workspace, float alpha, float beta, int num_gemms, + cudaStream_t stream, int device, int math_sm_count) { + using Gemm = GemmGrouped; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + + typename Gemm::Arguments arguments; + size_t kernel_workspace_size = Gemm::get_workspace_size(arguments); + auto gemm_coord_size = getGemmCoordSize(num_gemms); + auto ptr_size = getPtrSize(num_gemms); + auto ldd_size = getLddSize(num_gemms); + auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size; + + NVTE_CHECK( + param_workspace_size < kCPUWorkSpaceSize, + "Insufficient kCPUWorkSpaceSize size: required=", static_cast(param_workspace_size), + ", available=", static_cast(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM."); + + auto total_workspace_size = param_workspace_size + kernel_workspace_size; + transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]); + + NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=", + static_cast(total_workspace_size), + ", available=", static_cast(wspace->numel()), " for CUTLASS grouped GEMM."); + + char* workspace_ptr = reinterpret_cast(wspace->data.dptr); + + char* kernel_workspace_ptr = nullptr; + + char* host_workspace = getHostWorkspace(); + + ProblemShapeType* problem_sizes_host = reinterpret_cast(host_workspace); + + ElementA** ptr_A_host = reinterpret_cast(host_workspace + gemm_coord_size); + ElementB** ptr_B_host = reinterpret_cast(host_workspace + gemm_coord_size + ptr_size); + ElementC** ptr_C_host = + reinterpret_cast(host_workspace + gemm_coord_size + 2 * ptr_size); + int64_t* lda_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + int64_t* ldb_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + int64_t* ldc_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * 
ldd_size); + + for (size_t i = 0; i < num_gemms; i++) { + const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]); + const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]); + transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]); + + const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0]; + const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1]; + const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1]; + + auto problem = ProblemShapeType(m, n, k); + problem_sizes_host[i] = problem; + + ptr_A_host[i] = reinterpret_cast(inputA->data.dptr); + ptr_B_host[i] = reinterpret_cast(inputB->data.dptr); + ptr_C_host[i] = reinterpret_cast(outputD->data.dptr); + + lda_host[i] = LayoutA::packed({m, k}).stride(0); + ldb_host[i] = LayoutB::packed({k, n}).stride(0); + ldc_host[i] = LayoutC::packed({m, n}).stride(0); + } + + cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice, + stream); + + char* param_workspace_ptr = workspace_ptr; + ProblemShapeType* problem_sizes_device = reinterpret_cast(param_workspace_ptr); + const ElementA** ptr_A = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size); + const ElementB** ptr_B = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size); + ElementC** ptr_C = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 2 * ptr_size); + + StrideA* lda = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + StrideB* ldb = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + StrideC* ldc = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + kernel_workspace_ptr = workspace_ptr + param_workspace_size; + + arguments = MakeArguments( + num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + alpha, beta, device, math_sm_count); + + Gemm gemm; + + // Check can implement the kernel. + if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + } + + // Initialize the kernel. + if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + } + + // Execute the kernel in the current stream. + if (gemm.run(stream) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 3b74b8f195..e50d71040d 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -191,7 +191,7 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) * \param[in] stream CUDA stream used for the operation. 
*/ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, - cudaStream_t stream); + cudaStream_t stream); /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -268,7 +268,7 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp * \param[in] stream CUDA stream used for the operation. */ void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, float alpha, cudaStream_t stream); + float limit, float alpha, cudaStream_t stream); /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 50b33909fb..0c358328b6 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -133,12 +133,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor* workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream); +void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index af19300a96..398c0acbdd 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -66,7 +66,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { - cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); } bool gamma_in_weight_dtype = false; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 1aae72e152..82e360ed64 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -52,7 +52,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { - cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); } bool training = diff --git 
a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 4c82b8c81b..d38bf79963 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -579,14 +579,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor "Input and output_t must have the same shape for columnwise non-transpose case."); } } - - NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + // output may not be defined if rowwise quantization is not needed. + NVTE_CHECK(output.dtype == output_t.dtype, + "output and output_t need to have the same dtype."); + } NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT; size_t scale_t_k = scale_inv_t.shape[1]; scale_t_stride_x = columnwise_compact ? 1 : scale_t_k; scale_t_stride_y = columnwise_compact ? scale_t_k : 1; } + auto output_dtype = + rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype; const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); @@ -597,7 +602,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output.dtype, OutputType, + output_dtype, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d3c7d2b086..cdda201668 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1045,7 +1045,7 @@ def act_lu( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x.astype(jnp.float32), + x=x, activation_type=activation_type, quantizer=None, ) @@ -1178,8 +1178,8 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
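The act_lu change above (and the matching dact_lu change below) keeps the unfused current-tensor-scaling path in the input's own dtype: the activation is computed unquantized first and the quantizer is applied as a separate step. Roughly, as a sketch rather than the actual TE helper, assuming a quantizer with the quantize(x, dq_dtype=...) method used elsewhere in these extensions:

    # Sketch of the unfused current-tensor-scaling activation path.
    def act_then_quantize(x, activation_fn, quantizer):
        out = activation_fn(x)  # runs in x.dtype (e.g. bf16); no fused quantization
        if quantizer is None:
            return out
        # amax is taken from the full high-precision output, then the result
        # is quantized in a second pass.
        return quantizer.quantize(out, dq_dtype=x.dtype)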
out = dact_lu( - dz=dz.astype(jnp.float32), - x=x.astype(jnp.float32), + dz=dz, + x=x, activation_type=activation_type, quantizer=None, ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index de1877de5c..7a978c1b74 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -842,6 +842,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) output = normed_input * gamma + beta if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) @@ -867,6 +869,8 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): output = normed_input * gamma if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 113072131d..06dded1d86 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -526,10 +526,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, - lhs_is_trans, grad, workspace_list.data(), accumulate, - use_split_accumulator, num_math_sm, stream); + nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + grad, workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f00bd573f1..09384217c6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -358,7 +358,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] return [ *[None] @@ -366,7 +366,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] if qkv_format == "thd": return [ @@ -829,6 +829,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
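These hunks hoist the K/V split out of the flash_attn_fwd call sites: in the MLA case k_part and v_part already arrive as separate tensors, while in the MHA case they still have to be sliced out of the packed kv buffer, with the split axis depending on the layout. A condensed sketch of the repeated selection pattern (illustrative helper, not part of the diff):

    # For "bshd"/"sbhd" layouts kv is packed as [..., 2, heads, dim];
    # for "thd" it is an indexable pair of tensors.
    def split_kv(kv, qkv_format, enable_mla, k_part=None, v_part=None):
        if enable_mla:
            # MLA: K and V were produced separately upstream.
            return k_part, v_part
        if qkv_format in ("bshd", "sbhd"):
            return kv[..., 0, :, :], kv[..., 1, :, :]
        return kv[0], kv[1]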
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -838,19 +851,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=True, **fa_forward_kwargs, @@ -985,6 +989,22 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1001,19 +1021,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1144,6 +1155,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1160,19 +1184,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1269,6 +1284,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
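In the second branch above the MLA path additionally calls .contiguous() on k_part and v_part before the kernel launch, presumably because the views reaching this point can be strided and flash-attention kernels expect densely packed tensors. A small illustration of how such views lose contiguity (made-up shapes):

    import torch

    k = torch.randn(8, 2, 4, 16)        # [seq, batch, heads, dim], made-up sizes
    view = k[..., 0, :]                 # indexing an interior dim -> strided view
    print(view.is_contiguous())         # False
    k_part = view.contiguous()          # dense copy, safe to hand to the kernel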
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1278,19 +1306,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1865,7 +1884,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -1875,16 +1914,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -1895,12 +1926,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = 0 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2016,7 +2046,29 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2026,16 +2078,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv // 2, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2046,12 +2090,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not 
ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2160,7 +2203,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2170,16 +2233,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q // 2, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2190,12 +2245,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse_, *fa_backward_args_thd, @@ -2267,7 +2321,15 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q) - dkv_ = torch.empty_like(kv) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + dkv_ = torch.empty_like(kv) + dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] + dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2277,8 +2339,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) @@ -2287,12 +2349,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout, q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + k_part, + v_part, out, softmax_lse, *fa_backward_args_thd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7097f4ba0f..9b2b9a1ac3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -126,10 +126,10 @@ class FlashAttentionUtils: # Please follow these instructions to install FA3 v3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git -(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install +(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 -(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py""" +(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False @staticmethod @@ -434,8 +434,10 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13") + # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version + # until the cuDNN bug is resolved + if device_compute_capability == (8, 9): + logger.debug("Disabling FusedAttention for KV caching for sm89") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -477,11 +479,10 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( - use_flash_attention_3 and FlashAttentionUtils.v3_is_installed - ): - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + use_flash_attention_2 = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": logger.debug( @@ -508,10 +509,41 @@ def get_attention_backend( ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention_2 = False - if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): - if FlashAttentionUtils.v3_is_installed: - logger.debug("Disabling FlashAttention 3 for head_dim > 128") - use_flash_attention_3 = False + if use_flash_attention_3: + + def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if head_dim_qk > 256 or num_heads % num_gqa_groups != 0: + return False + if head_dim_qk != head_dim_v: + cond1 = 128 < head_dim_qk <= 192 + cond2 = 96 < head_dim_v <= 128 + cond3 = head_dim_qk <= 64 and head_dim_v <= 512 + if not ((cond1 and cond2) or cond3): + return False + if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16): + return False + return True + + if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if FlashAttentionUtils.v3_is_installed: + logger.debug( + "Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, " + 
"head_dim_qk, head_dim_v or qkv_dtype. " + "Supported: head_dim_qk <= 256, and num_heads %% num_gqa_groups = 0, and " + "if head_dim_qk is different from head_dim_v, then " + "(head_dim_qk must in (128, 192] and head_dim_v in (96, 128]) or " + "(head_dim_qk <= 64 and head_dim_v <= 512), and " + "if head_dim_qk is different from head_dim_v and head_dim_v > 256, then " + "qkv_dtype requires fp16 and bf16 data type. " + "Found: num_heads = %s, num_gqa_groups = %s, " + "head_dim_qk = %s, head_dim_v = %s and qkv_dtype = %s.", + num_heads, + num_gqa_groups, + head_dim_qk, + head_dim_v, + qkv_dtype, + ) + use_flash_attention_3 = False # Filter: QKV layout if qkv_format == "thd": diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 8d5417a45c..f0ef8d0bd5 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -215,6 +215,17 @@ def __init__( device=torch.cuda.current_device(), ) + # This internal buffer holds the running length of each + # unfinished sequence in the batch and is updated in `pre_step()` + # method. One use of this buffer is applying RoPE to q and k tensors + # during inference by slicing ROPE Embeddings according to the + # current sequence length window. + self.pre_step_seqlens = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() @@ -266,6 +277,15 @@ def pre_step( for k, v in self.sequences.items(): self.sequences_pre_step[k] = v - step_dict[k] + pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + + # Copy the pre-step seqlens to the device in CUDA Graphs safe manner. 
+ self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_( + pre_step_seqlens_temp, non_blocking=False + ) + seqlens_q = list(step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) @@ -280,9 +300,7 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to( - dtype=torch.int32, device="cpu" - ) + return self.pre_step_seqlens def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -458,14 +476,14 @@ def pre_step( finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( + self.batch_indices.data[:].copy_( torch.Tensor( ( unfinished_indices + finished_indices + list(range(prev_batch_size, self.max_batch_size)) ) - ).to(dtype=torch.int32, device="cpu") + ) ) # Advance unfinished sequences diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 9c82442af6..5fd16bf1a1 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -889,23 +889,11 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) - else: - raise ValueError( - f"qkv_format={self.qkv_format} not supported for KV caching and RoPE." - ) - - sequence_start = inference_params.get_seqlens_pre_step() - # sequence_start = inference_params.seqlens[0] - sequence_end = sequence_start + sequence_length - - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + # Applying RoPE for inference needs the start positions of sequences + # for each iteration.
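Taken together, the persistent pre_step_seqlens buffer and the new start_positions argument replace the old host-side slicing of the rotary tables: the device buffer is updated in place every step, which keeps its address stable for CUDA-graph capture, and apply_rotary_pos_emb then reads it as per-sequence position offsets. A reduced sketch of the flow (illustrative names and sizes; not the fused TE kernel):

    import torch

    max_batch_size = 8
    # Allocated once and reused every step, so the pointer recorded in a CUDA
    # graph stays valid across replays.
    pre_step_seqlens = torch.zeros(max_batch_size, dtype=torch.int32, device="cuda")

    def pre_step_update(sequences_pre_step):
        vals = torch.tensor(list(sequences_pre_step.values()), dtype=torch.int32)
        # In-place copy into the persistent buffer; the unused tail stays zero.
        pre_step_seqlens[: len(vals)].copy_(vals)

    def rope_angles_for_step(freqs, start_positions, new_token_counts):
        # Sequence b uses rotary angles for positions
        # [start_positions[b], start_positions[b] + new_token_counts[b]).
        return [freqs[s : s + n] for s, n in zip(start_positions, new_token_counts)]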
+ sequence_start_positions = ( + inference_params.get_seqlens_pre_step() if inference_params is not None else None + ) if pad_between_seqs: rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded @@ -922,6 +910,7 @@ def forward( cu_seqlens=rotary_pos_cu_seq_lens_q, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) key_layer = apply_rotary_pos_emb( @@ -932,6 +921,7 @@ def forward( cu_seqlens=rotary_pos_cu_seq_lens_kv, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index d1ba1a351c..064da8a670 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor if (start_positions) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } if (qkv_format == NVTE_QKV_Format::NVTE_THD) { diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f4768bb9ba..0d18a5ec5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -93,6 +93,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans bool use_split_accumulator, CommOverlapCore* comm_overlap, std::optional comm_type, MaybeTensor extra_output, bool bulk_overlap, float alpha, std::optional beta) { + using namespace transformer_engine::pytorch::detail; + // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -123,10 +125,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans "into D tensor. Beta has nothing to be applied to."); } + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); @@ -139,12 +141,35 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } + // maintain unquantized tensor in case we need unfused quantization support. + TensorWrapper unquantized_D_tensor; + py::object unquantized_out; + // Unfused quantization is needed in the following cases + // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) + // 2. 
Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, + // GEMM Output needs to be in BF16, to allow for unfused quantization) + bool unfused_quantization_needed = !quantizer.is_none(); + if (low_precision) { + // At the moment, only use-case for fused GEMM: + // Delayed scaling quantizer with per-tensor scaling inputs + bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr()); + if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input) + unfused_quantization_needed = false; + } + + if (unfused_quantization_needed) { + NoneQuantizer q{none}; + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + } + TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; + // Bias tensor TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + auto opts = + torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -157,7 +182,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : D_tensor.dtype(); + DType gelu_type = low_precision ? bias_type : out_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -210,7 +235,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -218,14 +243,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -234,14 +259,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, 
transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -251,15 +276,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(main_stream); + if (out_tensor.numel() != 0 && !accumulate) { + out_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -267,7 +292,11 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - + if (unfused_quantization_needed) { + // Quantize the output + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + my_quantizer->quantize(unquantized_D_tensor, D_tensor); + } // Pack outputs std::vector out; out.emplace_back(std::move(D)); @@ -448,11 +477,10 @@ std::optional> te_general_grouped_gemm( // For now, we only have multi-stream cublas backend. NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - te_bias_vector.data(), te_pre_gelu_out_vector.data(), - te_A_vector.size(), transa, transb, grad, - te_workspace_vector.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(), + transa, transb, grad, te_workspace_vector.data(), accumulate, + use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); }); return bias; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 59bac8fe5a..c63f892cea 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -110,7 +110,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); std::tie(unquantized_out_cu, unquantized_out) = my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); @@ -145,7 +146,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); } else { @@ -290,7 +292,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + 
!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); std::tie(unquantized_out_cu, unquantized_out) = my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); @@ -325,7 +328,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); } else { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c690cd522a..cd7e70fecb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8Quantizer::create_tensor( @@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - // quantize output and its transpose - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8CurrentScalingQuantizer::create_tensor( @@ -562,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // Change the rowwise and columnwise_data to the configured dtype. - // May be a switch between E5M2 and E4M3. 
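The guard added in the layernorm_fwd/rmsnorm_fwd hunks above means the current-scaling quantize_with_amax shortcut is taken only when the cuDNN norm path has not been forced through NVTE_NORM_FWD_USE_CUDNN; otherwise the output falls back to a plain quantize. The branch, condensed into an illustrative Python sketch (not the real C++):

    import os

    def quantize_norm_output(quantizer, unquantized_out, out, is_current_scaling):
        # Treats any non-empty value as "set"; the C++ helper does its own parsing.
        use_cudnn_norm = bool(os.getenv("NVTE_NORM_FWD_USE_CUDNN"))
        if is_current_scaling and not use_cudnn_norm:
            # Reuse the amax already produced by the unfused norm kernel.
            quantizer.quantize_with_amax(unquantized_out, out)
        else:
            # Generic path: recompute amax as part of a normal quantize.
            quantizer.quantize(unquantized_out, out)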
- auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -917,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize this->dtype = quantizer.attr("dtype").cast(); } -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9189ccc59..5749d96c9f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -883,7 +883,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] * self.num_gemms weight_quantizers = [ self.quantizers["scaling_fwd"][ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index cd02f31132..4d30be414e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -353,8 +353,11 @@ def forward( # Deallocate GEMM input tensor if no longer needed if not weight.requires_grad and not return_layernorm_output: - ln_out = ln_out_total = None clear_tensor_data(ln_out, ln_out_total) + ln_out = ln_out_total = None + elif with_input_all_gather and not return_layernorm_output_gathered: + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ -891,9 +894,19 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted - if not ctx.return_layernorm_output: + # Deallocate input tensors if permitted + if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + # Input tensors have not been exposed externally + clear_tensor_data(ln_out) + elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: + # Non-gathered input has not been exposed externally + clear_tensor_data(ln_out) + if ctx.ln_out_needs_gather: + # Gathered input is internal clear_tensor_data(ln_out_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if 
overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1169,7 +1182,9 @@ def __init__( self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output - self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.return_layernorm_output_gathered = ( + return_layernorm_output_gathered if return_layernorm_output else False + ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type @@ -1767,7 +1782,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a6c55ceb79..9f799c5538 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -445,14 +445,19 @@ def forward( act_out = activation_func(fc1_out, None) act_out = tex.quantize(act_out, fc2_input_quantizer) else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8_calibration: + act_out = activation_func(fc1_out, None) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) - if fp8_calibration: - fc2_input_quantizer.calibrate(act_out) - fc2_weight_quantizer.calibrate(fc2_weight) + if not fp8 and fp8_calibration: + if fc2_input_quantizer is not None: + fc2_input_quantizer.calibrate(act_out) + if fc2_weight_quantizer is not None: + fc2_weight_quantizer.calibrate(fc2_weight) # Configure Userbuffers reduce-scatter if needed ub_obj_fc2out = None @@ -1897,7 +1902,7 @@ def _get_quantizers(self, fp8_output): fc2_grad_output_quantizer, ) = [None] * 10 fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() - if self.fp8: + if self.fp8 or self.fp8_calibration: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] @@ -2114,7 +2119,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None, None] fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2ce6fb4c1d..7e526245c1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -317,6 +317,13 @@ def forward( # Finished forward GEMM... # ------------------------------------------------------ + # Deallocate GEMM input tensor if no longer needed + # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. 
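The clear_tensor_data calls added below release the all-gathered GEMM input as soon as the forward GEMM no longer needs it, rather than waiting for Python garbage collection, which trims peak activation memory under sequence parallelism. The pattern, reduced to its essentials (illustrative; clear_tensor_data is the TE utility that drops a tensor's storage):

    def forward_gemm_with_early_free(gemm, inputmat_total, input_was_gathered):
        out = gemm(inputmat_total)
        if input_was_gathered:
            # The gathered copy is internal to this op: release it now instead
            # of relying on GC of the autograd graph (see the TODO above).
            clear_tensor_data(inputmat_total)  # TE helper, assumed imported
            inputmat_total = None
        return out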
+ if with_input_all_gather_nccl: + clear_tensor_data(inputmat_total) + inputmat_total = None + # ------------------------------------------------------ # Prepare output tensor # Note: Perform tensor-parallel communication @@ -878,9 +885,16 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted + # Deallocate tensors if permitted if ctx.owns_input: + # Input tensor is internal clear_tensor_data(inputmat_total) + elif ctx.backward_input_needs_gather: + # Gathered input tensor is internal + clear_tensor_data(inputmat_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1643,7 +1657,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 8e997428f4..99bbc34c45 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -29,7 +29,9 @@ def maybe_dequantize( if is_quantized_tensor(tensor): return tensor.dequantize(dtype=dtype) if dtype is not None and tensor.dtype != dtype: - return tensor.to(dtype) + tensor = tensor.to(dtype) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() return tensor diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 961b472e0f..2ca42dbcd8 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -11,6 +11,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -111,6 +112,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -413,7 +416,11 @@ class ClampedSwiGLU(_ActivationOperation): Quantize input tensor when caching for use in the backward pass. 
""" - def __init__(self, *, limit: float, alpha: float, cache_quantized_input: bool = False): + def __init__(self, + *, + limit: float = 7.0, + alpha: float = 1.702, + cache_quantized_input: bool = False): super().__init__(cache_quantized_input=cache_quantized_input) self.limit = limit self.alpha = alpha diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 8336330558..70c70c54d2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,6 +13,7 @@ import torch from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -964,6 +965,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) ctx.save_for_backward(x_local, w) ctx.with_quantized_compute = with_quantized_compute ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index f0f55322c4..30ccf5ebcd 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,6 +9,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from .._common import maybe_autocast_dtype, maybe_dequantize @@ -70,6 +71,8 @@ def op_forward( # Save context for backward if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(mask) ctx.save_for_backward(mask) ctx.impl = impl ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index a340e7d42a..440fee34d1 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -10,10 +10,8 @@ import torch -from ...utils import clear_tensor_data from ... 
import torch_version -from .._common import maybe_dequantize -from ..op import BasicOperation, OperationContext +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -22,6 +20,9 @@ warmup_jit_l2normalization_all_dtypes, ) from ...tensor import Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize class L2Normalization(BasicOperation): @@ -101,6 +102,8 @@ def op_forward( # Save state for backward pass if requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3d8862e99c..91e6de07d7 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,6 +14,9 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -22,8 +25,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class LayerNorm(BasicOperation): @@ -215,6 +216,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 42d3fc101b..8c3f029747 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,6 +14,9 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -22,8 +25,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class RMSNorm(BasicOperation): @@ -196,6 +197,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index b87b12f840..02bcfee0ae 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import 
FP8GlobalStateManager from ...tensor import Quantizer +from ..basic import BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasActivation(FusedOperation): @@ -121,6 +118,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index dd59e602f2..15cc081c1d 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) -from transformer_engine.pytorch.tensor import Quantizer +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer +from ..basic import AddExtraInput, BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasAdd(FusedOperation): @@ -118,6 +115,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 448f72763a..21190d4fcf 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,14 +10,15 @@ import torch +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..op import ( FusedOperation, FusibleOperation, OperationContext, ) -from ...tensor import Quantizer class ForwardLinearScaleAdd(FusedOperation): @@ -95,6 +96,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 574642794f..a604e57dcd 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,6 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size from ...fp8 import FP8GlobalStateManager from ...module.base import ( @@ -353,6 +354,8 @@ def 
fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index adffe7c580..da0220eb7a 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -349,9 +349,14 @@ def _create_columnwise(self): def _transpose_columnwise_data(self): """Plainly transpose the columnwise data and scale inv.""" if self._columnwise_data is not None: + # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. + _old_data = self._columnwise_data self._columnwise_data = tex.fp8_transpose( self._columnwise_data, self._fp8_dtype, out=None ) + _old_data.data = _empty_tensor() + del _old_data def __repr__(self): if self._rowwise_data is not None: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 61edc999ac..6d48223443 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -95,8 +95,13 @@ def __new__( return instance def clear(self): - """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" - for t in (self._data, self._transpose, self._scale_inv): + """Deallocate this tensor's memory. Typically not needed and must be used carefully. + + Scale-inv tensor is not deallocated because it's often shared + between multiple FP8 tensors. + + """ + for t in (self._data, self._transpose): if t is not None: t.data = _empty_tensor() self._transpose_invalid = True From 7bf0bc4211eacccc77df7b659c8c4441b432e62d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Sep 2025 23:27:35 +0000 Subject: [PATCH 29/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/activation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 2ca42dbcd8..8a754c6382 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -416,11 +416,9 @@ class ClampedSwiGLU(_ActivationOperation): Quantize input tensor when caching for use in the backward pass. 
""" - def __init__(self, - *, - limit: float = 7.0, - alpha: float = 1.702, - cache_quantized_input: bool = False): + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): super().__init__(cache_quantized_input=cache_quantized_input) self.limit = limit self.alpha = alpha From fe93c015c1c7d2c648d573dc5344b8fb72d9a5d8 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:47:31 -0700 Subject: [PATCH 30/53] Use limit=0.75 in clamped SwiGLU test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3812eb28f8..8b4f671fe0 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1720,6 +1720,8 @@ def test_clamped_swiglu( quantization: Optional[str], quantize_forward: bool, quantize_backward: bool, + limit: float = 0.75, + alpha: float = 1.702, ): # Test SwiGLU variant used in GPT OSS. # Tensor dimensions @@ -1744,12 +1746,6 @@ def test_clamped_swiglu( test_device=device, requires_grad=False, ) - # A low value of limit = 0.1 is used for this test instead of the original - # default = 7.0 used in GPT OSS. This is because low value kills decent number - # of gradients allowing us to check for correctness of gradient computation of - # ClampedSwiGLU. - limit = 0.1 - alpha = 1.702 # Plain PyTorch implementation x_glu, x_linear = x_ref.chunk(2, dim=-1) From 5d3b169ef369dde5d1fbc1750a4fefbcd0c5d002 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 19 Sep 2025 05:51:25 +0000 Subject: [PATCH 31/53] Address review comments Signed-off-by: Varun Thumbe --- transformer_engine/common/activation/swiglu.cu | 4 ++-- transformer_engine/common/util/math.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 0264bc9fbb..cafc48abba 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -40,7 +40,7 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, NVTE_API_CALL(nvte_clamped_swiglu); using namespace transformer_engine; ClampedSwiGLUParam param = {limit, alpha}; - gated_act_fn>(input, output, param, stream); + gated_act_fn>(input, output, param, stream); } void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, @@ -48,6 +48,6 @@ void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETen NVTE_API_CALL(nvte_clamped_dswiglu); using namespace transformer_engine; ClampedSwiGLUParam param = {limit, alpha}; - dgated_act_fn, oss_dsilu>( + dgated_act_fn, clamped_dsilu>( grad, input, output, param, stream); } diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 5885652c5a..2f20817fb0 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -75,7 +75,7 @@ __device__ inline OType silu(const IType val, const Empty& e) { } template -__device__ inline OType oss_silu(const IType val, const ClampedSwiGLUParam& p) { +__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) { const float cval = min(p.limit, static_cast(val)); // Clamping return qgelu_with_alpha(cval, p.alpha); } @@ -87,7 +87,7 @@ __device__ inline 
OType dsilu(const IType val, const Empty& e) { } template -__device__ inline OType oss_dsilu(const IType val, const ClampedSwiGLUParam& p) { +__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) { const bool dclamp_val = static_cast(val) <= p.limit; const float clamp_val = min(static_cast(val), p.limit); const float dsilu_val = dqgelu_with_alpha(clamp_val, p.alpha); From 0c17c7e9653f79837bf8f0b152bcf442123ea507 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 24 Sep 2025 01:29:57 +0000 Subject: [PATCH 32/53] JAX integration changes Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 89 +++++-- .../common/util/cast_gated_kernels.cuh | 2 +- .../common/util/vectorized_pointwise.h | 2 +- transformer_engine/jax/activation.py | 31 ++- .../jax/cpp_extensions/activation.py | 251 ++++++++++++++---- transformer_engine/jax/csrc/extensions.h | 18 +- .../jax/csrc/extensions/activation.cpp | 55 ++-- .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/flax/module.py | 13 +- transformer_engine/jax/flax/transformer.py | 5 + transformer_engine/jax/layernorm_mlp.py | 27 +- 11 files changed, 386 insertions(+), 108 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9e39b84c0b..a50017e776 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -170,29 +170,35 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { "L0": [ ("gelu",), ("gelu", "linear"), + ("clamped_silu", "clamped_linear"), ], "L2": ALL_ACTIVATION_TYPES, } class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer) + def primitive_func(self, inputs, activation_type, quantizer, act_params): + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -209,12 +215,20 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, 
act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -234,7 +248,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -242,9 +257,21 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -273,10 +300,18 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -296,10 +331,18 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) @@ -1450,4 +1493,4 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) - assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) \ No newline at end of file diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 7f5d68fc62..3dab0ccc5c 100644 --- 
a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -747,7 +747,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate.data.elt[e]); float after_act_elt; float after_gate_elt; - float dgate_elt = true; + bool dgate_elt = true; if constexpr (std::is_same::value) { // In case of GPT OSS, clamp the activation and gate values dgate_elt = gate_elt < p.limit && gate_elt > -p.limit; // Derivative of clamp diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 959eb8ea90..4ad1c16de8 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -338,7 +338,7 @@ template void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, - const Param params, cudaStream_t stream) { + const Param &params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43c..d154779f31 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -8,10 +8,11 @@ from typing import Sequence, Union, Callable, Optional from functools import partial +from dataclasses import dataclass import jax import jax.numpy as jnp - +import numpy as np from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor @@ -22,6 +23,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[tex.activation.ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. @@ -32,17 +34,19 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -52,36 +56,43 @@ def _activation(x, activation_type, quantizer): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU.
Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. ctx: Context from forward pass g: Gradient from upstream @@ -90,7 +101,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) @@ -98,4 +109,4 @@ def _activation_bwd_rule(activation_type, ctx, g): return (dx, None) -_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule) +_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule) \ No newline at end of file diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index cdda201668..259f17c681 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -6,16 +6,18 @@ import operator from functools import reduce, partial from packaging import version +from dataclasses import dataclass + import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -27,7 +29,7 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_dbias, _quantize_dbias_impl +from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -37,10 +39,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] @@ -56,17 +54,106 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } -def _convert_to_activation_function(fn_or_string): +@dataclass(frozen=True) +class ClampedSwigluParams: + limit: float = 7.0 + alpha: float = 1.702 + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + def __hash__(self): + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + + +@dataclass(frozen=True) +class ActivationParams: + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + + # Add other 
activation-specific parameter fields here as needed in the future + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + CLAMPED_ACTIVATION_TYPES = { + ("clamped_silu", "clamped_linear"), + "clamped_silu", + "clamped_linear", + } + if activation_type in CLAMPED_ACTIVATION_TYPES: + return ActivationParams(ClampedSwigluParams(**kwargs)) + else: + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + +@dataclass(frozen=True) +class ClampedSwigluParams: + limit: float = 7.0 + alpha: float = 1.702 + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + def __hash__(self): + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + + +@dataclass(frozen=True) +class ActivationParams: + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + + # Add other activation-specific parameter fields here as needed in the future + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + CLAMPED_ACTIVATION_TYPES = { + ("clamped_silu", "clamped_linear"), + "clamped_silu", + "clamped_linear", + } + if activation_type in CLAMPED_ACTIVATION_TYPES: + return ActivationParams(ClampedSwigluParams(**kwargs)) + else: + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + # This function is used for ClampedSwiGLU + # used in GPT OSS where the gates are not only clamped + # but also shifted by +1 + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha + return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -89,7 +176,8 @@ class ActLuPrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params inner_primitive = None outer_primitive = None @@ -105,11 +193,12 @@ def abstract( is_2x, scale_dtype, is_outer, + act_params, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -155,6 +244,7 @@ def lowering( is_2x, scale_dtype, is_outer, + act_params, ): """ te_gated_act_lu_p lowering rules @@ -163,9 +253,14 @@ def lowering( x_aval, 
scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -180,6 +275,7 @@ def impl( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe implementation @@ -198,6 +294,7 @@ def impl( is_2x=is_2x, scale_dtype=scale_dtype, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -226,6 +323,7 @@ def batcher( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -247,6 +345,7 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + act_params=act_params, ), out_bdims, ) @@ -260,6 +359,7 @@ def infer_sharding_from_operands( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -271,6 +371,7 @@ def infer_sharding_from_operands( scale_dtype, act_len, is_outer, + act_params, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -323,6 +424,7 @@ def partition( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -383,6 +485,7 @@ def sharded_impl(x, scale): is_2x=is_2x, scale_dtype=scale_dtype, is_outer=True, + act_params=act_params, ) ) @@ -410,32 +513,35 @@ def shardy_sharding_rule( is_2x, scale_dtype, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) + + del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params + prefix = "ActLu_" + input_shape = value_types[0].shape + output_shape = input_shape[:-2] + input_shape[-1:] + # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 + output_shape, unique_var=prefix + "x", flatten_axis=-1 ) - x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) - out = (*x_axes[:-2], x_axes[-1]) - scale_inv = scale_rules.rowwise_rule + x_axes = scale_rules.input_spec + # Correct input spec with act dim + x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] + out = scale_rules.input_spec colwise_out = (prefix + "out_colwise",) colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple( - multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) - ) + colwise_out = multidim_transpose(out, transpose_axis=-1) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule - # amax is always a unit tensor. 
amax = (prefix + "amax",) return SdyShardingRule( @@ -443,7 +549,8 @@ def shardy_sharding_rule( x_axes, ("…1",), ), - (out, colwise_out, scale_inv, colwise_scale_inv, amax), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), + **scale_rules.factor_sizes, ) @@ -458,8 +565,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params + impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) inner_primitive = None outer_primitive = None @@ -477,6 +584,7 @@ def abstract( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p abstract @@ -533,6 +641,7 @@ def abstract( jax_dtype_to_te_dtype(out_dtype), scaling_mode, is_2x, + act_params, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -578,6 +687,7 @@ def lowering( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p lowering rules @@ -596,6 +706,7 @@ def lowering( is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), ) @staticmethod @@ -611,6 +722,7 @@ def impl( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p impl @@ -630,6 +742,7 @@ def impl( act_enum=act_enum, act_len=act_len, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -658,6 +771,7 @@ def batcher( act_enum, act_len, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -688,6 +802,7 @@ def batcher( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, ), out_bdims, ) @@ -702,11 +817,12 @@ def infer_sharding_from_operands( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum + del out_dtype, result_infos, act_enum, act_params del scale_dtype, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -777,6 +893,7 @@ def partition( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -857,6 +974,7 @@ def sharded_impl(dz, x, scale): act_enum=act_enum, act_len=act_len, is_outer=True, + act_params=act_params, ) ) if is_dbias: @@ -883,31 +1001,38 @@ def shardy_sharding_rule( act_enum, act_len, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" + + del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params + prefix = "DActLuDBias_" + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) return SdyShardingRule( (dz_axes, x_axes, 
("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) @@ -922,20 +1047,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -950,10 +1077,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -961,7 +1090,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -984,6 +1114,8 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -992,6 +1124,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. 
Returns: If quantizer is None: @@ -1005,24 +1138,22 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) - + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) - if quantizer is None: out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, @@ -1034,6 +1165,7 @@ def act_lu( is_2x=False, scale_dtype=jnp.float32, is_outer=True, + act_params=act_params, ) out = out.reshape(output_shape) out = NoScaleTensor( @@ -1048,10 +1180,16 @@ def act_lu( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, + ) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1071,6 +1209,7 @@ def act_lu( is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), is_outer=True, + act_params=act_params, ) quantizer.update(updated_amax) @@ -1093,6 +1232,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1109,7 +1249,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. 
""" - + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1122,8 +1262,7 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1139,6 +1278,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) output = output.astype(x.dtype) dbias = None @@ -1154,7 +1294,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1171,6 +1315,7 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params, ) if war_output is not None: return war_output @@ -1182,6 +1327,7 @@ def quantize_dact_dbias( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1194,7 +1340,10 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1220,6 +1369,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1248,6 +1398,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1261,11 +1412,13 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. 
""" + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params, ) - return output + return output \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 59079fe3f0..7ad3dc00d6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -38,6 +38,15 @@ XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu_config; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -134,4 +143,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember("clamped_swiglu")); +#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906bb..90301ab1c7 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + bool is_2x_int, ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu_config.limit; + auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -133,20 +140,23 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Arg() // scale - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Attr("act_enum") - .Attr("scaling_mode") - .Attr("is_2x"), - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv 
colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x") + .Attr( + "act_params"), // Can generalize the config later if we have more activations that need params + FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -216,7 +226,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + int64_t act_enum, bool is_2x, bool is_dbias, + ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu_config.limit; + auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -383,6 +397,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -408,7 +426,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias"), + .Attr("is_dbias") + .Attr("act_params"), FFI_CudaGraph_Traits); } // namespace jax -} // namespace transformer_engine +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index afbeb644c1..08600fd3f4 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -133,6 +133,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54efa..f02876d8f4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters needed(if any) by the activation functions specified in :attr:`activations`. + At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS + need additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. 
intermediate_dropout_rate: float, default = 0.1 @@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1023,6 +1028,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1037,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1150,6 +1158,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1287,4 +1296,4 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype - return out, ln_output # Output, layner_norm_output + return out, ln_output # Output, layer_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9ae..fc72b9bc3f 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1631,6 +1631,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_activations: Sequence[str], default = ('relu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. 
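[Editor's note] To illustrate how the activation_params / mlp_activation_params attributes documented in the hunk above are meant to be used, here is a minimal flax-level sketch. It is not part of the patch series; the import path and the tensor shapes are assumptions, while the attribute names, the ("clamped_silu", "clamped_linear") pair, and the limit/alpha defaults come from these diffs.

    # Hedged usage sketch (not from the patch): ClampedSwiGLU via the flax module.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax import flax as te_flax  # assumed public import path

    mlp = te_flax.LayerNormMLP(
        intermediate_dim=4096,                              # illustrative size
        activations=("clamped_silu", "clamped_linear"),     # gated pair added by this series
        activation_params={"limit": 7.0, "alpha": 1.702},   # forwarded to ActivationParams.create
    )
    x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)        # (batch, seq, hidden), illustrative
    variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
    out, ln_out = mlp.apply(variables, x, deterministic=True)  # return_layernorm_output=True by default

The same dictionary is exposed one level up as TransformerLayer.mlp_activation_params and is only consulted when the clamped activations are selected.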
@@ -1751,6 +1754,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None mlp_activations: Sequence[str] = ("relu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False @@ -2045,6 +2049,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801af..9d59629427 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -48,6 +49,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -129,6 +131,7 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ) return output @@ -154,6 +157,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -203,6 +207,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ) return output @@ -227,6 +232,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ): """Forward pass rule for layernorm_mlp. 
@@ -272,13 +278,12 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + amax_scope=AmaxScope.TPSP, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, - flatten_axis=-2, - quantizer=ffn1_quantizer_set.kernel, + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP ) # NN GEMM @@ -306,10 +311,18 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) + # At the moment the act_params is only used for ClampedSwiglu + # If there are more activations that require parameters in the future + # we might need to change it to a more generic parameter container casted_act_out = tex.act_lu( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -317,6 +330,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) # NN GEMM @@ -371,6 +385,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, ctx, grad, ): @@ -417,6 +432,7 @@ def _layernorm_mlp_bwd_rule( grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -457,6 +473,11 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From 66c7086ce3ebd83d1b964c678edd7b39343abf58 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Sep 2025 01:41:59 +0000 Subject: [PATCH 33/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 2 +- tests/pytorch/attention/test_attention_with_cp.py | 1 - transformer_engine/jax/activation.py | 2 +- transformer_engine/jax/cpp_extensions/activation.py | 3 +-- transformer_engine/jax/cpp_extensions/normalization.py | 1 - transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/activation.cpp | 2 +- 7 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index a50017e776..d6192d33aa 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1493,4 +1493,4 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) - assert_allclose(prim_dbias, ref_dbias, dtype=dtype) \ No newline at end of file + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 30b42263d6..0c8a34876b 100644 --- 
a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -105,7 +105,6 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if not flash_attn_supported: pytest.skip("No attention backend available.") - subprocess.run( get_bash_arguments( num_gpus_per_node=num_gpus, diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 2e7476a40a..c4e7cf2fa2 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -108,4 +108,4 @@ def _activation_bwd_rule(activation_type, act_params, ctx, g): return (dx, None) -_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule) \ No newline at end of file +_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d7fd8bbd17..0995a4580f 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1044,7 +1044,6 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" - def _jax_act_lu( inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None ) -> Union[NoScaleTensor, ScaledTensor]: @@ -1419,4 +1418,4 @@ def dact_lu( quantizer=quantizer, act_params=act_params, ) - return output \ No newline at end of file + return output diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index a7295a02fc..3348c725be 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1169,7 +1169,6 @@ def rmsnorm_fwd( quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, - ) return out, rsigma diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 7ad3dc00d6..1a55cc52cd 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -150,4 +150,4 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConf XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::ActivationConfig, ::xla::ffi::StructMember("clamped_swiglu")); -#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ \ No newline at end of file +#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 90301ab1c7..7e7e3178b4 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -430,4 +430,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("act_params"), FFI_CudaGraph_Traits); } // namespace jax -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine From af19dbf654cc7cea6622304fa193b1914eb5b00b Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 23 Sep 2025 18:43:30 -0700 Subject: [PATCH 34/53] revert line break Signed-off-by: vthumbe1503 --- tests/pytorch/attention/test_attention_with_cp.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0c8a34876b..c752d07d82 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -94,7 +94,6 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ) if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") - dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} available_backends, *_ = get_available_attention_backends( config, From 4f29915a1f34425bec24277d5740631927ac0e88 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 23 Sep 2025 18:44:24 -0700 Subject: [PATCH 35/53] revert line break Add documentation for quantization function parameters and return value. Signed-off-by: vthumbe1503 --- transformer_engine/jax/cpp_extensions/quantization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index e34ed8cde3..021af4c9db 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -875,6 +875,7 @@ def quantize_dbias( Defaults to -1. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + Returns: A tuple containing: - A ScaledTensor containing the quantized input tensor. From 24828f3ea738cf5dea70b9a27f82dcc0f8da4663 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 24 Sep 2025 02:01:16 +0000 Subject: [PATCH 36/53] missed adding oss swiglu to nvte enum in common Signed-off-by: Varun Thumbe --- .../common/include/transformer_engine/activation.h | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index e50d71040d..4e48088586 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU }; /*! \brief Computes the GeLU activation of the input. 
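[Editor's note] With the CLAMPED_SWIGLU enum entry above in place, the JAX-side path introduced earlier in this series can be exercised end to end. The following is a minimal sketch, not part of the patches: the import paths and tensor shapes are assumptions, while ActivationParams.create, the activation() entry point, and the ("clamped_silu", "clamped_linear") pair are taken from the diffs.

    # Hedged usage sketch (not from the patch): low-level ClampedSwiGLU call.
    import jax.numpy as jnp
    from transformer_engine.jax import cpp_extensions as tex   # assumed import path
    from transformer_engine.jax.activation import activation   # assumed import path

    activation_type = ("clamped_silu", "clamped_linear")
    act_params = tex.activation.ActivationParams.create(
        activation_type, limit=7.0, alpha=1.702                # ClampedSwigluParams defaults
    )

    # Gated activations expect the input replicated along the -2 axis:
    # shape (..., act_len, hidden) with act_len == 2 for the gated pair.
    x = jnp.ones((8, 128, 2, 4096), dtype=jnp.bfloat16)
    y = activation(x, activation_type, quantizer=None, act_params=act_params)
    # y has shape (8, 128, 4096)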
From 19410b624f72a0cfb8c823a65a73e9d5d3b3179b Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 24 Sep 2025 03:17:56 +0000 Subject: [PATCH 37/53] fix jax linting errors Signed-off-by: Varun Thumbe --- examples/jax/encoder/test_multigpu_encoder.py | 1 + .../jax/cpp_extensions/activation.py | 67 +++++++------------ 2 files changed, 24 insertions(+), 44 deletions(-) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index bc6a567521..b0bf185511 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -53,6 +53,7 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, + mlp_activations=('silu', 'linear'), ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 0995a4580f..a32f15a489 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,8 +5,8 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial -from packaging import version from dataclasses import dataclass +from packaging import version import jax import jax.numpy as jnp @@ -59,62 +59,36 @@ @dataclass(frozen=True) class ClampedSwigluParams: - limit: float = 7.0 - alpha: float = 1.702 """Parameters for the Clamped SwiGLU activation function used in GPT OSS.""" - - def __hash__(self): - return hash((self.limit, self.alpha)) - - def to_ffi_lowering_dict(self): - return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} - - -@dataclass(frozen=True) -class ActivationParams: - clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() - - # Add other activation-specific parameter fields here as needed in the future - @staticmethod - def create(activation_type, **kwargs): - """Factory method to create ActivationParams based on activation_type.""" - CLAMPED_ACTIVATION_TYPES = { - ("clamped_silu", "clamped_linear"), - "clamped_silu", - "clamped_linear", - } - if activation_type in CLAMPED_ACTIVATION_TYPES: - return ActivationParams(ClampedSwigluParams(**kwargs)) - else: - return ActivationParams() # Default params for activations without parameters - - def __hash__(self): - return hash((self.clamped_swiglu,)) - - def to_ffi_lowering_dict(self): - return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} - - -@dataclass(frozen=True) -class ClampedSwigluParams: limit: float = 7.0 alpha: float = 1.702 - """Parameters for the Clamped SwiGLU activation function - used in GPT OSS.""" def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work. + + Returns: + int: Hash value of the dataclass instance. + """ return hash((self.limit, self.alpha)) def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} @dataclass(frozen=True) class ActivationParams: + """Parameters for various activation functions. + Currently only Clamped SwiGLU activation has parameters. 
+ """ clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() - - # Add other activation-specific parameter fields here as needed in the future + @staticmethod def create(activation_type, **kwargs): """Factory method to create ActivationParams based on activation_type.""" @@ -125,13 +99,18 @@ def create(activation_type, **kwargs): } if activation_type in CLAMPED_ACTIVATION_TYPES: return ActivationParams(ClampedSwigluParams(**kwargs)) - else: - return ActivationParams() # Default params for activations without parameters + return ActivationParams() # Default params for activations without parameters def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work""" return hash((self.clamped_swiglu,)) def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} From 5480d29a834c1d1570557c2574fce8de64019214 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Sep 2025 03:18:27 +0000 Subject: [PATCH 38/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/encoder/test_multigpu_encoder.py | 2 +- transformer_engine/jax/cpp_extensions/activation.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index b0bf185511..fd95e8920e 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -53,7 +53,7 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - mlp_activations=('silu', 'linear'), + mlp_activations=("silu", "linear"), ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index a32f15a489..c0793847d0 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -61,6 +61,7 @@ class ClampedSwigluParams: """Parameters for the Clamped SwiGLU activation function used in GPT OSS.""" + limit: float = 7.0 alpha: float = 1.702 @@ -87,8 +88,9 @@ class ActivationParams: """Parameters for various activation functions. Currently only Clamped SwiGLU activation has parameters. """ + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() - + @staticmethod def create(activation_type, **kwargs): """Factory method to create ActivationParams based on activation_type.""" @@ -108,7 +110,7 @@ def __hash__(self): def to_ffi_lowering_dict(self): """Convert the activation parameters to a dictionary format for FFI lowering. Returns: - dict: A dictionary representation of the activation parameters consumable by + dict: A dictionary representation of the activation parameters consumable by XLA FFI bindings for activation functions. 
""" return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} From 7a917ea9c41ab4e8111aab3821113cf42f2c5bac Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 24 Sep 2025 03:26:35 +0000 Subject: [PATCH 39/53] fix jax linting errors Signed-off-by: Varun Thumbe --- transformer_engine/jax/activation.py | 2 -- transformer_engine/jax/cpp_extensions/activation.py | 1 - 2 files changed, 3 deletions(-) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index c4e7cf2fa2..daa3679c48 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -8,11 +8,9 @@ from typing import Sequence, Union, Callable, Optional from functools import partial -from dataclasses import dataclass import jax import jax.numpy as jnp -import numpy as np from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index c0793847d0..db7aa34179 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -6,7 +6,6 @@ import operator from functools import reduce, partial from dataclasses import dataclass -from packaging import version import jax import jax.numpy as jnp From 53dd179b10a55d1a9541c25deab29fdf11347963 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 24 Sep 2025 03:29:11 +0000 Subject: [PATCH 40/53] revert multi_gpu_encoder change Signed-off-by: Varun Thumbe --- examples/jax/encoder/test_multigpu_encoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index fd95e8920e..bc6a567521 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -53,7 +53,6 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - mlp_activations=("silu", "linear"), ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) From 3bfae54a5eae0ccab1d6ba58e05cef23e3f55f19 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 25 Sep 2025 02:02:11 +0000 Subject: [PATCH 41/53] fix flax integration bug Signed-off-by: Varun Thumbe --- transformer_engine/jax/cpp_extensions/activation.py | 1 - transformer_engine/jax/layernorm_mlp.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index db7aa34179..8a03eac3f8 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -619,7 +619,6 @@ def abstract( jax_dtype_to_te_dtype(out_dtype), scaling_mode, is_2x, - act_params, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 9d59629427..20541e719b 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -137,7 +137,7 @@ def layernorm_mlp( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, From 38382dcf69e806ffc0e3e20c6808bcbe0d4bbfd9 
Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 25 Sep 2025 02:05:19 +0000 Subject: [PATCH 42/53] fix linting error Signed-off-by: Varun Thumbe --- transformer_engine/jax/cpp_extensions/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 8a03eac3f8..925c1d01ae 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -567,7 +567,7 @@ def abstract( """ te_dact_dbias_quantize_p abstract """ - del act_enum + del act_enum, act_params dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype From c7ef0780fcc47c585c07acd789e85e00bf990da3 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 26 Sep 2025 18:58:36 +0000 Subject: [PATCH 43/53] bug fixed in other branch and not here Signed-off-by: Varun Thumbe --- transformer_engine/common/util/cast_gated_kernels.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index edac5a7c62..1c1578ac53 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -188,7 +188,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, p.limit); const float s = sigmoidf(p.alpha * x); act_x = x * s; - if (x < p.limit) { + if (act_elt < p.limit) { dact_x = s + s * (1 - s) * p.alpha * x; } else { dact_x = 0.0f; @@ -507,7 +507,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, p.limit); const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x < p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; + dact_x = act_elt < p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); @@ -762,7 +762,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = min(act_elt, p.limit); const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x < p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; + dact_x = act_elt < p.limit ? 
s + s * (1 - s) * p.alpha * x : 0.0f; } else { if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { const float s = sigmoidf(x); From 2a2e6def0c0cd272f75f0307f05dabe858eda30a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Sep 2025 18:52:54 +0000 Subject: [PATCH 44/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/csrc/extensions.h | 1 - transformer_engine/jax/layernorm_mlp.py | 1 - 2 files changed, 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1086130ef6..3edc99ecf9 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -142,7 +142,6 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine - XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, ::xla::ffi::StructMember("limit"), ::xla::ffi::StructMember("alpha")); diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 15388d760f..35a130e0b4 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -146,7 +146,6 @@ def layernorm_mlp( return output - @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, From b2f4fcbfe0ded2f1cac88aaa3fd3e619f6570739 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 29 Sep 2025 21:49:42 +0000 Subject: [PATCH 45/53] bug in dbias computation Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 1 + transformer_engine/jax/cpp_extensions/activation.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index d6192d33aa..1c5bb7e828 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -777,6 +777,7 @@ def test_quantize_dbias( def _test_quantize_dact_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 925c1d01ae..75a3518af7 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1281,6 +1281,7 @@ def quantize_dact_dbias( ) is_gated = act_len == 2 + print(f"is_gated: {is_gated}, act_len: {act_len}") # TE/common does not support DelayedScaling2x for gated-act yet if is_gated: war_output = try_apply_delayed_scaling_2x_war( @@ -1299,8 +1300,8 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
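[Editorial sketch, not part of the patch series] The cast_gated_kernels.cuh hunks in the patch above switch the derivative gate of the GPT-OSS clamped SwiGLU from x < p.limit to act_elt < p.limit, i.e. the condition is evaluated on the raw gate input rather than the already-clamped value. A minimal NumPy restatement of that math (the function name is ours; the 7.0 / 1.702 defaults are the ClampedSwigluParams defaults used elsewhere in this series):

    import numpy as np

    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))

    def clamped_silu_fwd_bwd(z, limit=7.0, alpha=1.702):
        # Gate half of clamped SwiGLU: forward value and derivative w.r.t. z.
        x = np.minimum(z, limit)          # clamp the gate input from above
        s = sigmoid(alpha * x)
        act = x * s                       # forward: SiLU of the clamped input
        # Chain rule for x * sigmoid(alpha * x) while the clamp is inactive,
        # zero once it saturates; the condition uses the raw input z, as in
        # the fixed kernel.
        dact = np.where(z < limit, s + s * (1.0 - s) * alpha * x, 0.0)
        return act, dact

For z < limit this is d/dz [z * sigmoid(alpha*z)] = sigmoid(alpha*z) + alpha*z*sigmoid(alpha*z)*(1 - sigmoid(alpha*z)), which is the expression the kernel stores in dact_x.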
out = dact_lu( - dz=dz, - x=x, + dz=dz.astype(jnp.float32), + x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, act_params=act_params, From 4f41c1b73a918b840e36b66b1ceb28062a537c42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Oct 2025 04:34:10 +0000 Subject: [PATCH 46/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index b7adbadfd7..8041f2fd51 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -777,7 +777,7 @@ def test_quantize_dbias( def _test_quantize_dact_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): - + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) From d2072b13e7d0f123acf42b2629ff5f5b708a480d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 1 Oct 2025 18:35:56 +0000 Subject: [PATCH 47/53] address review comments Signed-off-by: Varun Thumbe --- transformer_engine/jax/csrc/extensions/activation.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 167eae9bc5..443df364a6 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -154,8 +154,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("act_enum") .Attr("scaling_mode") .Attr("is_2x") - .Attr( - "act_params"), // Can generalize the config later if we have more activations that need params + .Attr("act_params"), FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, From 978fcde029f51546ce8cbddb0e7360823017423d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Oct 2025 18:38:08 +0000 Subject: [PATCH 48/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/activation.cpp | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 443df364a6..4ce1f1c92c 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -140,22 +140,21 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - ActLuHandler, ActLuFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Arg() // scale - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Attr("act_enum") - .Attr("scaling_mode") - .Attr("is_2x") - .Attr("act_params"), - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x") + 
.Attr("act_params"), + FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, From 13a3e3c3b59c92cfc9db28a54d535375125d21cc Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 1 Oct 2025 18:50:29 +0000 Subject: [PATCH 49/53] minor bug because of merge conflict Signed-off-by: Varun Thumbe --- .../jax/csrc/extensions/activation.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 443df364a6..b71234e1bb 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -161,10 +161,10 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + JAXX_Scaling_Mode scaling_mode, bool is_2x_int, ActivationConfig act_params) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, - act_enum, scaling_mode, is_2x_int); + act_enum, scaling_mode, is_2x_int, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, @@ -179,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x")); + .Attr("is_2x") + .Attr("act_params")); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -460,11 +461,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type inp Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, - bool is_2x, bool is_dbias) { + bool is_2x, bool is_dbias, + ActivationConfig act_params) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, + act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -484,7 +487,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias")); + .Attr("is_dbias") + .Attr("act_params")); } // namespace jax } // namespace transformer_engine From d59526b4e57e801e062790e986c8bb4153b04f1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Oct 2025 18:51:14 +0000 Subject: [PATCH 50/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/activation.cpp | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 7fd57749d8..887d1728c6 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ 
b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -160,7 +160,8 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int, ActivationConfig act_params) { + JAXX_Scaling_Mode scaling_mode, bool is_2x_int, + ActivationConfig act_params) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, act_enum, scaling_mode, is_2x_int, act_params); @@ -453,20 +454,16 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("act_params"), FFI_CudaGraph_Traits); -Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, - Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t act_enum, - bool is_2x, bool is_dbias, - ActivationConfig act_params) { +Error_Type DActLuDBiasQuantizeInitializeFFI( + cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, + bool is_dbias, ActivationConfig act_params) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, - act_params); + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, From 47835144e398f04750434aa009644a3ce8c6e6f8 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 1 Oct 2025 20:23:08 +0000 Subject: [PATCH 51/53] accept copilot suggestion Signed-off-by: Varun Thumbe --- transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/activation.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c3cc6ac429..bbfc62120a 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -42,7 +42,7 @@ struct ClampedSwigluConfig { }; struct ActivationConfig { - ClampedSwigluConfig clamped_swiglu_config; + ClampedSwigluConfig clamped_swiglu; }; inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 7fd57749d8..cf7fe83a88 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -20,8 +20,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, ActivationConfig act_params) { // parameters for clamped swiglu used in 
GPT OSS - auto swiglu_limit = act_params.clamped_swiglu_config.limit; - auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -252,8 +252,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params) { // parameters for clamped swiglu used in GPT OSS - auto swiglu_limit = act_params.clamped_swiglu_config.limit; - auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); From 14f89711f0a34f65386e7da9023be387b6e7120e Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 1 Oct 2025 22:53:16 +0000 Subject: [PATCH 52/53] fix test and remove a redundant test addition Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 2 +- tests/pytorch/test_fusible_ops.py | 72 ------------------- .../jax/cpp_extensions/activation.py | 5 +- 3 files changed, 3 insertions(+), 76 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8041f2fd51..ab4866ca26 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -829,7 +829,7 @@ def _test_quantize_dact_dbias( (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. or ( - activation_type == ("squared_relu",) + activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")} and in_dtype == jnp.bfloat16 and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING ) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e97d9a1853..231fa64bc1 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1810,78 +1810,6 @@ def test_clamped_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_list) - @pytest.mark.parametrize("quantize_forward", (False, True)) - @pytest.mark.parametrize("quantize_backward", (False, True)) - def test_clamped_swiglu( - self, - *, - out_shape: Iterable[int] = (32, 32), - dtype: torch.dtype, - device: torch.device = "cuda", - quantization: Optional[str], - quantize_forward: bool, - quantize_backward: bool, - limit: float = 0.75, - alpha: float = 1.702, - ): - # Test SwiGLU variant used in GPT OSS. 
- # Tensor dimensions - in_shape = list(out_shape) - in_shape[-1] *= 2 - - # Skip invalid configurations - quantized_compute = quantization is not None - if not quantized_compute and (quantize_forward or quantize_backward): - pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) - - # Random data - x_ref, x_test = make_reference_and_test_tensors( - in_shape, - test_dtype=dtype, - test_device=device, - ) - dy_ref, dy_test = make_reference_and_test_tensors( - out_shape, - test_dtype=dtype, - test_device=device, - requires_grad=False, - ) - - # Plain PyTorch implementation - x_glu, x_linear = x_ref.chunk(2, dim=-1) - x_glu = x_glu.clamp(min=None, max=limit) - x_linear = x_linear.clamp(min=-limit, max=limit) - out_glu = x_glu * torch.sigmoid(alpha * x_glu) - y_ref = out_glu * (x_linear + 1) - y_ref.backward(dy_ref) - - # Implementation with fusible operation - recipe = make_recipe(quantization) - - forward = te_ops.Sequential( - te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), - te_ops.Quantize(forward=quantize_forward, backward=False), - ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): - y_test = forward(x_test) - - y_test.backward(dy_test) - - # Expected numerical error - tols = dtype_tols(dtype) - if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) - - # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @pytest.mark.parametrize("dtype", _dtypes) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 75a3518af7..925c1d01ae 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1281,7 +1281,6 @@ def quantize_dact_dbias( ) is_gated = act_len == 2 - print(f"is_gated: {is_gated}, act_len: {act_len}") # TE/common does not support DelayedScaling2x for gated-act yet if is_gated: war_output = try_apply_delayed_scaling_2x_war( @@ -1300,8 +1299,8 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
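[Editorial sketch, not part of the patch series] The hunk above drops a duplicated PyTorch test for te_ops.ClampedSwiGLU, which the commit message calls redundant. For readers of this section only, the plain-PyTorch reference math from that removed test, restated as a standalone helper (the function name is ours; limit=0.75 and alpha=1.702 are the defaults from the removed test's signature):

    import torch

    def clamped_swiglu_ref(x, limit=0.75, alpha=1.702):
        # Reference math from the removed test: gate half clamped from above,
        # linear half clamped to [-limit, limit], output is gate * (linear + 1).
        x_glu, x_linear = x.chunk(2, dim=-1)
        x_glu = x_glu.clamp(min=None, max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
        out_glu = x_glu * torch.sigmoid(alpha * x_glu)
        return out_glu * (x_linear + 1)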
out = dact_lu( - dz=dz.astype(jnp.float32), - x=x.astype(jnp.float32), + dz=dz, + x=x, activation_type=activation_type, quantizer=None, act_params=act_params, From 5a55b0dbb8b03a88b3d2dc1c49ab266ae31080e0 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 3 Oct 2025 16:22:00 +0000 Subject: [PATCH 53/53] address review comments Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 1 - transformer_engine/jax/layernorm_mlp.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ab4866ca26..7a4fa268af 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -177,7 +177,6 @@ def assert_dequantized_grouped_scaled_tensor( "L0": [ ("gelu",), ("gelu", "linear"), - ("clamped_silu", "clamped_linear"), ], "L2": ALL_ACTIVATION_TYPES, } diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 35a130e0b4..77daa4672c 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -336,9 +336,6 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) - # At the moment the act_params is only used for ClampedSwiglu - # If there are more activations that require parameters in the future - # we might need to change it to a more generic parameter container casted_act_out = tex.act_lu( dot_1_output, activation_type,
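[Editorial sketch, not part of the patch series] For the JAX side of the series: the ClampedSwigluParams / ActivationParams dataclasses added in the earlier transformer_engine/jax/cpp_extensions/activation.py hunks are frozen and define __hash__ so they can travel as static arguments under jax.jit, and to_ffi_lowering_dict() turns them into the nested dict the XLA FFI attributes expect. A small usage sketch under those assumptions (the import path mirrors the module shown in the diffs; the call site itself is hypothetical):

    from transformer_engine.jax.cpp_extensions.activation import ActivationParams

    # Activations without extra parameters fall back to the defaults.
    default_params = ActivationParams.create(("gelu", "linear"))

    # The clamped GPT-OSS SwiGLU variants accept limit/alpha overrides.
    oss_params = ActivationParams.create(
        ("clamped_silu", "clamped_linear"), limit=7.0, alpha=1.702
    )

    # Frozen dataclass + custom __hash__: usable as a jit static argument,
    # and equal parameter values hash identically.
    assert hash(oss_params) == hash(
        ActivationParams.create(("clamped_silu", "clamped_linear"))
    )

    # What the FFI lowering consumes: a nested dict of numpy float32 scalars,
    # e.g. {"clamped_swiglu": {"limit": 7.0, "alpha": 1.702}}.
    ffi_attrs = oss_params.to_ffi_lowering_dict()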