diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index b8a0bad0ca08..e5bc55c32751 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -25,7 +25,7 @@
 from tvm.relax.transform import PatternCheckContext

 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_matmul_pattern
+from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
 from ..utils import has_leaking_intermediate_variables
@@ -48,6 +48,16 @@ def _check_matmul(context: PatternCheckContext) -> bool:
     rhs = context.annotated_expr["rhs"]
     matmul_call = context.annotated_expr["root"]

+    if "scale" in context.annotated_expr and "zp" in context.annotated_expr:
+        scale = context.annotated_expr["scale"]
+        zero_point = context.annotated_expr["zp"]
+        # Only scalar values for scale and zero_point are supported.
+        if scale.struct_info.ndim != 0 or zero_point.struct_info.ndim != 0:
+            return False
+        # Only zero_point == 0.0 is supported.
+        if zero_point.data.numpy()[()].item() != 0.0:
+            return False
+
     lhs_dtype = lhs.struct_info.dtype
     rhs_dtype = rhs.struct_info.dtype
     out_dtype = matmul_call.struct_info.dtype
@@ -187,11 +197,16 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             ),
             _check_matmul,
         ),
+        (
+            "cublas.matmul_transposed_dequantize",
+            *make_matmul_dequantize_pattern(transposed_rhs=True),
+            _check_matmul,
+        ),
     ]
 )


-def partition_for_cublas(mod):
+def partition_for_cublas(mod, bind_constants=False):
     """
     Partition the input module into cuBLAS-supported subgraphs.
@@ -200,6 +215,9 @@
     mod: tvm.IRModule
         The IRModule to be partitioned.

+    bind_constants : bool
+        Whether or not to keep bound constants in the grouped function.
+
     Returns
     -------
     mod: tvm.IRModule
@@ -208,4 +226,6 @@
     """

     patterns = get_patterns_with_prefix("cublas")
-    return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
+    return transform.FuseOpsByPattern(
+        patterns, bind_constants=bind_constants, annotate_codegen=True
+    )(mod)
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 23de175b24f6..404f7dc97526 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -336,6 +336,46 @@ def make_rms_norm_pattern():
     return out, annotations


+def make_matmul_dequantize_pattern(
+    transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    Create pattern for matrix multiplication followed by a dequantize operation.
+
+    Parameters
+    ----------
+    transposed_rhs: bool
+        Whether the right hand side of the multiplication is transposed.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a matrix multiplication followed by dequantize.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract important expressions from
+        match result, to power the partition check function and codegen.
+    """
+
+    lhs = wildcard()
+    rhs = wildcard()
+    annotations = {"lhs": lhs, "rhs": rhs}
+
+    if transposed_rhs:
+        rhs = is_op("relax.permute_dims")(rhs)
+
+    out = is_op("relax.matmul")(lhs, rhs)
+    annotations["root"] = out
+
+    scale = is_const()
+    zp = is_const()
+    annotations.update({"scale": scale, "zp": zp})
+
+    out = is_op("relax.dequantize")(out, scale, zp)
+
+    return out, annotations
+
+
 def make_attention_rewrite_pattern(
     qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
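
For context, a minimal usage sketch of the Python-side changes above (illustrative only, not part of the patch). It assumes the get_relax_matmul_dequantize_module helper added to tests/python/relax/test_codegen_cublas.py later in this patch, and mirrors what test_matmul_fp8_dequantize_offload exercises.

    from tvm import relax
    from tvm.relax.backend.contrib.cublas import partition_for_cublas

    # Module with matmul followed by R.dequantize (scalar scale, zero_point == 0),
    # built with the test helper defined later in this patch.
    mod = get_relax_matmul_dequantize_module(
        (10, 32), (64, 32), "e4m3_float8", "float16",
        transposed_y=True, scale_const=0.34786, zero_point_const=0.0,
    )
    # bind_constants=True keeps the scale constant inside the fused function so the
    # cuBLAS codegen below can fold it into cuBLASLt's alpha parameter.
    mod = partition_for_cublas(mod, bind_constants=True)
    mod = relax.transform.RunCodegen()(mod)
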
+ """ + + lhs = wildcard() + rhs = wildcard() + annotations = {"lhs": lhs, "rhs": rhs} + + if transposed_rhs: + rhs = is_op("relax.permute_dims")(rhs) + + out = is_op("relax.matmul")(lhs, rhs) + annotations["root"] = out + + scale = is_const() + zp = is_const() + annotations.update({"scale": scale, "zp": zp}) + + out = is_op("relax.dequantize")(out, scale, zp) + + return out, annotations + + def make_attention_rewrite_pattern( qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False ): diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index e573d9a12385..9f29d21aaa3d 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -22,6 +22,7 @@ * \brief Implementation of the CUBLAS JSON serializer. */ #include +#include #include @@ -74,6 +75,25 @@ class CublasJSONSerializer : public JSONSerializer { auto node = std::make_shared(composite_name, /* name_ */ "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); + if (composite_name.find("dequantize") != std::string::npos) { + const CallNode* dequantize_call = backend::GetOpInFunction(fn, "relax.dequantize"); + if (dequantize_call->args[1]->IsInstance()) { + const auto* const_expr = dequantize_call->args[1].as(); + auto sinfo = Downcast(const_expr->struct_info_); + float alpha = 1.0; + if (sinfo->dtype == DataType::Float(16)) { + alpha = __gnu_h2f_ieee(static_cast(const_expr->data->data)[0]); + } else { + ICHECK(sinfo->dtype == DataType::Float(32)); + alpha = static_cast(const_expr->data->data)[0]; + } + + std::vector dq_scale = {backend::to_str(alpha)}; + std::vector dq_scale_attr; + dq_scale_attr.emplace_back(dq_scale); + node->SetAttr("dq_scale", dq_scale_attr); + } + } const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); SetCallNodeAttribute(node, root_call); diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index 412651d3f990..e0195a61950f 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -137,6 +137,18 @@ inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { */ Map ExtractArgIdx(String pattern_name, Function f); +/*! + * \brief Converts a numeric value to std::string. + * \param value A numeric value to convert. + * \return String representation of a numeric value. + */ +template +std::string to_str(const Type& value) { + std::ostringstream os; + os << std::setprecision(12) << value; + return os.str(); +} + } // namespace backend } // namespace relax } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 553d4014c0b4..1edb6b95c962 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -138,7 +138,8 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; } void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, - void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue) { + void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue, + std::optional dq_scale) { ICHECK(TypeEqual(A->dtype, B->dtype)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? 
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 553d4014c0b4..1edb6b95c962 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -138,7 +138,8 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A,
                   const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue) {
+                  void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue,
+                  std::optional<float> dq_scale) {
   ICHECK(TypeEqual(A->dtype, B->dtype));
   // Reversed strides indicates an in-place transpose operation.
   transa = IsInPlaceTransposed(A) ? !transa : transa;
@@ -152,7 +153,10 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
   float zero_fp32 = 0.0;
   int32_t one_i32 = 1;
   int32_t zero_i32 = 0;
-  void* alpha = &one_fp32;
+  // Pass the dequantization scale through the "alpha" parameter. If there is no dequantization
+  // after the matmul, then alpha == 1.0.
+  float alpha_value = dq_scale.value_or(one_fp32);
+  void* alpha = &alpha_value;
   void* beta = &zero_fp32;

   if (TypeMatch(A->dtype, kDLFloat, 16)) {
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 1a072a92eb8b..8578d86789b8 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -129,9 +129,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
       auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);

+      std::optional<float> dq_scale = std::nullopt;
+      if (op_name.find("dequantize") != std::string::npos) {
+        dq_scale = std::stof(node.GetAttr<std::vector<std::string>>("dq_scale")[0]);
+      }
+
       tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
                                  b_ptr, bias_ptr, out_ptr, transa, transb,
-                                 entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue);
+                                 entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue,
+                                 dq_scale);
     }
   }
 }
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 5c5cb6920860..2906279f904a 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -34,6 +34,7 @@
 #if CUDART_VERSION >= 10010
 #include <cublasLt.h>
 #endif  // CUDART_VERSION >= 10010
+#include <optional>

 namespace tvm {
 namespace contrib {
@@ -124,7 +125,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A,
                   const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa,
                   bool transb, void* workspace_ptr, size_t workspace_size,
-                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT);
+                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
+                  std::optional<float> dq_scale = std::nullopt);

 }  // namespace contrib
 }  // namespace tvm
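
To make the runtime contract above concrete: with a scalar scale and zero_point == 0, R.dequantize of the matmul result is just an elementwise scaling of the product, which cuBLASLt applies for free when alpha is set to the dequantization scale. A small NumPy sketch of that identity (illustrative only, not part of the patch):

    import numpy as np

    x = np.random.rand(10, 32).astype("float32")
    y = np.random.rand(32, 64).astype("float32")
    scale = 0.34786  # the dq_scale handed to CallCublasLt as alpha

    def dequantize(q, scale, zero_point):
        # Scalar-case dequantize semantics: scale * (q - zero_point).
        return scale * (q - zero_point)

    reference = dequantize(x @ y, scale, zero_point=0.0)  # matmul followed by dequantize
    alpha_gemm = scale * (x @ y)                          # GEMM with alpha = dq_scale
    np.testing.assert_allclose(reference, alpha_gemm, rtol=1e-6)
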
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index ea0861467faa..4ff498ae2b93 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -24,6 +24,8 @@
 from tvm.relax.backend.contrib.cublas import partition_for_cublas
 from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder

 try:
     import ml_dtypes
@@ -60,8 +62,8 @@ def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
     return f(*inputs).numpy()


-def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False):
-    mod = partition_for_cublas(mod)
+def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_constants=False):
+    mod = partition_for_cublas(mod, bind_constants=bind_constants)
     mod = relax.transform.RunCodegen()(mod)
     return build_and_run(mod, np_inputs, "cuda", cuda_graph)
@@ -95,6 +97,43 @@ def _to_concrete_shape(symbolic_shape, var_table):
 }


+def get_relax_matmul_dequantize_module(
+    x_shape,
+    y_shape,
+    in_dtype,
+    out_dtype,
+    transposed_y=False,
+    scale_const=1.0,
+    zero_point_const=0.0,
+):
+    """Create a matmul op followed by a dequantize operation."""
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, in_dtype))
+            y = R.arg("y", R.Tensor(y_shape, in_dtype))
+
+            with R.dataflow() as frame:
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype="float32"))
+                result = R.emit(
+                    R.dequantize(
+                        result,
+                        scale=R.const(scale_const, "float16"),
+                        zero_point=R.const(zero_point_const, "float16"),
+                        axis=-1,
+                        out_dtype=out_dtype,
+                    )
+                )
+                R.output(result)
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
     [
@@ -262,6 +301,32 @@ def test_matmul_fp8_offload(
     tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)


+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_dequantize_offload():
+    x_shape = (10, 32)
+    y_shape = (64, 32)
+    in_dtype = "e4m3_float8"
+    mod = get_relax_matmul_dequantize_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        "float16",
+        transposed_y=True,
+        scale_const=0.34786,
+        zero_point_const=0.0,
+    )
+
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    args = (x, y)
+
+    out = get_result_with_relax_cublas_offload(mod, args, bind_constants=True)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
@@ -283,6 +348,29 @@ def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y, partition_done):
     assert func_name in mod["main"].script()


+@pytest.mark.parametrize(
+    "M, N, K, scale, zp, num_bindings",
+    [
+        (16, 64, 32, 2.0, 0.0, 1),
+        (16, 64, 32, 2.0, 1.0, 2),
+        (16, 64, 32, [2.0] * 64, [2.0] * 64, 2),
+    ],
+)
+def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings):
+    mod = get_relax_matmul_dequantize_module(
+        (M, K),
+        (N, K),
+        "e4m3_float8",
+        "float16",
+        transposed_y=True,
+        scale_const=scale,
+        zero_point_const=zp,
+    )
+    mod = partition_for_cublas(mod)
+    # Check whether R.dequantize is still in the main function or not
+    assert len(mod["main"].body.blocks[0].bindings) == num_bindings
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))