Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
Original file line number Diff line number Diff line change
@@ -14,6 +14,8 @@
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>

//
@@ -201,26 +203,6 @@ void test_reference_linear_qcs4w(
ASSERT_TRUE(at::allclose(out, out_ref));
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_vulkan_linear_qga4w_impl(
const int B,
const int M,
22 changes: 2 additions & 20 deletions backends/vulkan/test/op_tests/rotary_embedding_test.cpp
Original file line number Diff line number Diff line change
@@ -14,6 +14,8 @@
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>

//
@@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
// Test functions
//

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_reference(
const int n_heads = 4,
const int n_kv_heads = 2,
20 changes: 2 additions & 18 deletions backends/vulkan/test/op_tests/sdpa_test.cpp
Original file line number Diff line number Diff line change
@@ -18,6 +18,8 @@
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>

#include "test_utils.h"

#include <cassert>
#include <iostream>

@@ -261,24 +263,6 @@ void test_reference_sdpa(
}
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_vulkan_sdpa(
const int start_input_pos,
const int base_sequence_len,
37 changes: 35 additions & 2 deletions backends/vulkan/test/op_tests/targets.bzl
Original file line number Diff line number Diff line change
@@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False):
platforms = get_platforms(),
)

runtime.cxx_library(
name = "test_utils",
srcs = [
"test_utils.cpp",
],
headers = [
"test_utils.h",
],
exported_headers = [
"test_utils.h",
],
deps = [
"//executorch/backends/vulkan:vulkan_graph_runtime",
"//executorch/runtime/core/exec_aten:lib",
runtime.external_dep_location("libtorch"),
],
visibility = [
"//executorch/backends/vulkan/test/op_tests/...",
"@EXECUTORCH_CLIENTS",
],
)

define_test_targets(
"compute_graph_op_tests",
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
@@ -150,9 +172,20 @@ def define_common_targets(is_fbcode = False):
define_test_targets(
"sdpa_test",
extra_deps = [
":test_utils",
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/tensor:tensor",
]
)
define_test_targets("linear_weight_int4_test")
define_test_targets("rotary_embedding_test")
define_test_targets(
"linear_weight_int4_test",
extra_deps = [
":test_utils",
]
)
define_test_targets(
"rotary_embedding_test",
extra_deps = [
":test_utils",
]
)
114 changes: 114 additions & 0 deletions backends/vulkan/test/op_tests/test_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "test_utils.h"

#include <stdexcept>

executorch::aten::ScalarType at_scalartype_to_et_scalartype(
at::ScalarType dtype) {
using ScalarType = executorch::aten::ScalarType;
switch (dtype) {
case at::kByte:
return ScalarType::Byte;
case at::kChar:
return ScalarType::Char;
case at::kShort:
return ScalarType::Short;
case at::kInt:
return ScalarType::Int;
case at::kLong:
return ScalarType::Long;
case at::kHalf:
return ScalarType::Half;
case at::kFloat:
return ScalarType::Float;
case at::kDouble:
return ScalarType::Double;
default:
throw std::runtime_error("Unsupported dtype");
}
}

std::string scalar_type_name(c10::ScalarType dtype) {
switch (dtype) {
case c10::kLong:
return "c10::kLong";
case c10::kShort:
return "c10::kShort";
case c10::kComplexHalf:
return "c10::kComplexHalf";
case c10::kComplexFloat:
return "c10::kComplexFloat";
case c10::kComplexDouble:
return "c10::kComplexDouble";
case c10::kBool:
return "c10::kBool";
case c10::kQInt8:
return "c10::kQInt8";
case c10::kQUInt8:
return "c10::kQUInt8";
case c10::kQInt32:
return "c10::kQInt32";
case c10::kBFloat16:
return "c10::kBFloat16";
case c10::kQUInt4x2:
return "c10::kQUInt4x2";
case c10::kQUInt2x4:
return "c10::kQUInt2x4";
case c10::kFloat:
return "c10::kFloat";
case c10::kHalf:
return "c10::kHalf";
case c10::kInt:
return "c10::kInt";
case c10::kChar:
return "c10::kChar";
case c10::kByte:
return "c10::kByte";
case c10::kDouble:
return "c10::kDouble";
case c10::kUInt16:
return "c10::kUInt16";
case c10::kBits16:
return "c10::kBits16";
default:
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
}
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kHalf:
return vkapi::kHalf;
case c10::kFloat:
return vkapi::kFloat;
case c10::kDouble:
return vkapi::kDouble;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kLong;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
case c10::kShort:
return vkapi::kShort;
case c10::kUInt16:
return vkapi::kUInt16;
default:
VK_THROW(
"Unsupported at::ScalarType: ",
scalar_type_name(at_scalartype),
" (",
static_cast<int>(at_scalartype),
")");
}
}
32 changes: 32 additions & 0 deletions backends/vulkan/test/op_tests/test_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <string>

#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

/**
* Convert at::ScalarType to executorch::ScalarType
*/
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
at::ScalarType dtype);

/**
* Get the string name of a c10::ScalarType for better error messages
*/
std::string scalar_type_name(c10::ScalarType dtype);

/**
* Convert c10::ScalarType to vkcompute::vkapi::ScalarType
*/
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype);