diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..a137a7d538f 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,52 +56,97 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", - "int": "iimage3D", - "uint": "uimage3D", + # integer dtypes "int8": "iimage3D", "uint8": "uimage3D", + "int16": "iimage3D", + "uint16": "uimage3D", + "int32": "iimage3D", + "uint32": "uimage3D", + "int64": "iimage3D", + "uint64": "uimage3D", + # common dtype aliases "bool": "uimage3D", + "int": "iimage3D", + "uint": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", - "int": "iimage2D", - "uint": "uimage2D", + # integer dtypes "int8": "iimage2D", "uint8": "uimage2D", + "int16": "iimage2D", + "uint16": "uimage2D", + "int32": "iimage2D", + "uint32": "uimage2D", + "int64": "iimage2D", + "uint64": "uimage2D", + # common dtype aliases "bool": "uimage2D", + "int": "iimage2D", + "uint": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", - "int": "isampler3D", - "uint": "usampler3D", + # integer dtypes "int8": "isampler3D", "uint8": "usampler3D", + "int16": "isampler3D", + "uint16": "usampler3D", + "int32": "isampler3D", + "uint32": "usampler3D", + "int64": "isampler3D", + "uint64": "usampler3D", + # common dtype aliases "bool": "usampler3D", + "int": "isampler3D", + "uint": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", - "int": "isampler2D", - "uint": "usampler2D", + # integer dtypes "int8": "isampler2D", "uint8": "usampler2D", + "int16": "isampler2D", + "uint16": "usampler2D", + "int32": "isampler2D", + "uint32": "usampler2D", + "int64": "isampler2D", + "uint64": "usampler2D", + # common dtype aliases "bool": "usampler2D", + "int": "isampler2D", + "uint": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba32f", "float": "rgba32f", "half": "rgba16f", - "int": "rgba32i", - "uint": "rgba32ui", + # integer dtypes "int8": "rgba8i", "uint8": "rgba8ui", + "int16": "rgba16i", + "uint16": "rgba16ui", + "int32": "rgba32i", + "uint32": "rgba32ui", + "int64": "rgba32i", + "uint64": "rgba32ui", + # common dtype aliases "bool": "rgba8ui", + "int": "rgba32i", + "uint": "rgba32ui", }, } @@ -118,10 +163,18 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + # integer dtype alias conversion elif dtype == "bool": return "uint8_t" + # we don't want to append _t for int32 or uint32 as int is already 32bit + elif dtype == "int32" or dtype == "uint32": + return "int" if dtype == "int32" else "uint" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -129,22 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str: if n == 1: return buffer_scalar_type(dtype) - if dtype == "float": - return f"vec{n}" - if dtype == "uint": - return f"uvec{n}" - elif dtype == "half": - return f"f16vec{n}" - elif dtype == "int": - return f"ivec{n}" - elif dtype == "int8": - return f"i8vec{n}" - elif dtype == "uint8": - return f"u8vec{n}" - elif dtype == "bool": - return f"u8vec{n}" - - raise AssertionError(f"Invalid dtype: {dtype}") + dtype_map = { + "half": f"f16vec{n}", + "float": f"vec{n}", + "double": f"vec{n}", # No 
64bit image format support in GLSL + "int8": f"i8vec{n}", + "uint8": f"u8vec{n}", + "int16": f"i16vec{n}", + "uint16": f"u16vec{n}", + "int32": f"ivec{n}", + "int": f"ivec{n}", + "uint32": f"uvec{n}", + "uint": f"uvec{n}", + "int64": f"ivec{n}", # No 64bit image format support in GLSL + "uint64": f"uvec{n}", # No 64bit image format support in GLSL + "bool": f"u8vec{n}", + } + + vector_type = dtype_map.get(dtype) + if vector_type is None: + raise AssertionError(f"Invalid dtype: {dtype}") + + return vector_type def texel_type(dtype: str) -> str: @@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": - nbit = "16bit" - glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "double": + # We only need to allow float64_t type usage + glsl_type = "float64" + elif dtype in ["int8", "uint8", "bool"]: nbit = "8bit" glsl_type = "int8" + elif dtype in ["int16", "uint16"]: + nbit = "16bit" + glsl_type = "int16" + elif dtype in ["int64", "uint64"]: + # We only need to allow int64_t and uint64_t type usage + glsl_type = "int64" - if nbit is not None and glsl_type is not None: + if nbit is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + if glsl_type is not None: out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str @@ -629,6 +695,10 @@ def generateVariantCombinations( elif "VALUE" in value: suffix = value.get("SUFFIX", value["VALUE"]) + if value["VALUE"] in ["int", "uint"]: + raise ValueError( + f"Use int32 or uint32 instead of {value['VALUE']}" + ) param_values.append((param_name, suffix, value["VALUE"])) else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml index e3df8bf73a1..37b2027db85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml @@ -7,13 +7,13 @@ arange: parameter_names_with_default_values: NDIM: 3 - DTYPE: int + DTYPE: int32 STORAGE: texture3d PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: arange diff --git a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml index eddddec0d8d..b1e16dec8d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml @@ -13,6 +13,6 @@ avg_pool2d: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: avg_pool2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c0efdd81eb9..accfcf53599 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -17,7 +17,7 @@ binary_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: binary_add - NAME: binary_sub diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..e8bb86dbf6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -12,8 +12,9 @@ buffer_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: 
int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index e48eab63a64..679e686dc2f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -13,9 +13,10 @@ buffer_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_nchw - NAME: buffer_to_nchw_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml index 414bf8191b9..984d9a09d43 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -7,6 +7,6 @@ copy_channel_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml index 87df7bf9dc1..09f5ca36ea4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml @@ -7,7 +7,7 @@ copy_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml index e872d64e3c3..6e55876cb28 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml @@ -7,6 +7,6 @@ copy_packed_dim_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml index 5ffe37265b1..0e7b491c433 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml @@ -7,6 +7,6 @@ embedding: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml index 646fd05e420..f5e7c874773 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml @@ -6,8 +6,9 @@ flip: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: flip diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 804ce19bdb8..646d8f1be81 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -14,9 +14,10 @@ image_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index 5a6c525993e..abef2225cd9 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -7,6 +7,6 @@ index_select: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index 66cb7ec3f89..a306e3ce47d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -7,6 +7,6 @@ index_select_channel: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..99e41a0ab6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -13,9 +13,10 @@ nchw_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 4674822ce6a..f3f604e10cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -87,5 +87,9 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); + $if DTYPE == "double" or DTYPE == "int64": + VEC4_T texel = read_texel(tidx); + write_texel(t_out, lpos_to_pos(lpos, axis_map), texel); + $else: + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..85119c8d508 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -14,9 +14,10 @@ nchw_to_image: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index e64e1bd260a..bfeaba2496b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -12,7 +12,7 @@ no_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index f678aeedf6e..a90ddcb41ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -7,6 +7,6 @@ permute: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: permute diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml index 526980a0f41..f40d94142e1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml @@ -7,7 +7,7 @@ repeat: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + -
VALUE: int32 - VALUE: int8 - VALUE: uint8 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index f13393ce6c7..47f538aee6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -15,9 +15,9 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) - - NAME: clamp_int + - NAME: clamp_int32 OPERATOR: clamp(X, A, B) - DTYPE: int + DTYPE: int32 - NAME: cos OPERATOR: cos(X) - NAME: exp diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index ba11a2496a0..33364a25225 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -7,6 +7,6 @@ view: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..6388a8ad091 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,24 +34,42 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; case vkapi::kHalf: kernel_name += "_half"; break; - case vkapi::kInt: - kernel_name += "_int"; - break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; break; case vkapi::kByte: - case vkapi::kQUInt8: case vkapi::kBool: + case vkapi::kQUInt8: kernel_name += "_uint8"; break; + case vkapi::kShort: + kernel_name += "_int16"; + break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; + case vkapi::kInt: + kernel_name += "_int32"; + break; + case vkapi::kUInt: + kernel_name += "_uint32"; + break; + case vkapi::kLong: + kernel_name += "_int64"; + break; + case vkapi::kUInt64: + kernel_name += "_uint64"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index f25fe95d72b..b3309aa6c69 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -30,11 +30,17 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(uint32_t, VK_FORMAT_R32G32B32A32_UINT, UInt) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(uint64_t, VK_FORMAT_R64G64B64A64_UINT, UInt64) \ + _(int64_t, VK_FORMAT_R64G64B64A64_SINT, Long) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -86,17 +92,29 @@ inline VkFormat to_vkformat(const ScalarType t) { */ inline ScalarType element_scalartype(const VkFormat vkformat) { switch (vkformat) { + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; + case VK_FORMAT_R32G32B32A32_SFLOAT: + return kFloat; + case VK_FORMAT_R16G16B16A16_SFLOAT: 
+ return kHalf; case VK_FORMAT_R8G8B8A8_SINT: return kChar; case VK_FORMAT_R8G8B8A8_UINT: case VK_FORMAT_R8G8B8A8_UNORM: return kByte; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; case VK_FORMAT_R32G32B32A32_SINT: return kInt; - case VK_FORMAT_R32G32B32A32_SFLOAT: - return kFloat; - case VK_FORMAT_R16G16B16A16_SFLOAT: - return kHalf; + case VK_FORMAT_R32G32B32A32_UINT: + return kUInt; + case VK_FORMAT_R64G64B64A64_SINT: + return kLong; + case VK_FORMAT_R64G64B64A64_UINT: + return kUInt64; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 37403c97ac8..4ef934eb105 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -51,7 +51,7 @@ idx_fill_texture: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp new file mode 100644 index 00000000000..24c856e9d46 --- /dev/null +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -0,0 +1,675 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <gtest/gtest.h> + +#include <ATen/ATen.h> + +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> + +#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> +#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> + +#include "test_utils.h" + +#include <cassert> +#include <iostream> + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +// Wrapper function for choose_qparams_tensor_out without context +Tensor& choose_qparams_tensor_out_no_context( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_tensor_out( + input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); + return scale_out; +} + +// Wrapper function for choose_qparams_per_token_asymmetric_out without context +Tensor& choose_qparams_per_token_asymmetric_out_no_context( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_per_token_asymmetric_out( + input, dtype, scale_out, zero_point_out); + return scale_out; +} + +// ATen wrapper for choose_qparams_tensor +std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto scale_out = at::empty({}, 
at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + double eps = 1e-7; + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) + (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +// ATen wrapper for choose_qparams_per_token_asymmetric +std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten( + const at::Tensor& input, + at::ScalarType dtype) { + // Calculate output sizes for scale and zero_point tensors + std::vector<int64_t> output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + auto scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) + (input, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +} // namespace native +} // namespace executor +} // namespace torch + +// +// Reference Implementation +// + +/* + * Reference implementation of choose_qparams_tensor + */ +std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max) { + // Create output tensors + at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + + // Find min and max values in the input tensor + float min_val = input.min().item<float>(); + float max_val = input.max().item<float>(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast<double>(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = quant_min - min_val / static_cast<double>(scale); + double zero_point_from_max = quant_max - max_val / static_cast<double>(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? 
zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point)); + } + + // Set output values - use fill_() for scalar (zero-dim) tensors + scale_out.fill_(scale); + zero_point_out.fill_(nudged_zero_point); + + return std::make_tuple(scale_out, zero_point_out); +} + +/* + * Reference implementation of choose_qparams_per_token_asymmetric + */ +std::tuple<at::Tensor, at::Tensor> +choose_qparams_per_token_asymmetric_reference_impl( + const at::Tensor& input, + at::ScalarType dtype) { + // For per-token quantization, we need to compute scale and zero_point for + // each token + int64_t quant_min = -128; + int64_t quant_max = 127; + + // Calculate output sizes + std::vector<int64_t> output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Create output tensors + at::Tensor scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + // Calculate number of tokens + int64_t num_tokens = 1; + for (int64_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + + // Process each token + for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { + at::Tensor token = reshaped_input[token_idx]; + + // Find min and max values for this token + float min_val = token.min().item<float>(); + float max_val = token.max().item<float>(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast<double>(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = + quant_min - min_val / static_cast<double>(scale); + double zero_point_from_max = + quant_max - max_val / static_cast<double>(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ?
zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = + std::nearbyint(static_cast<float>(initial_zero_point)); + } + + // Set output values for this token - use index_put_ for safety + scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); + zero_point_out.view({num_tokens, 1}) + .index_put_({token_idx, 0}, nudged_zero_point); + } + + return std::make_tuple(scale_out, zero_point_out); +} + +// Forward declaration of implementation functions +void test_vulkan_choose_qparams_tensor_impl( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector<int>& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_tensor( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_per_token_asymmetric( + const std::vector<int>& input_sizes, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_choose_qparams_tensor( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_tensor_reference_impl(input, quant_min, quant_max); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << 
"implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_tensor_impl( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Build Vulkan choose_qparams_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + // Output tensors + const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams.tensor") + (graph, + { + r_input.value, + r_quant_min, + r_quant_max, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); + at::Tensor vk_zero_point = + at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + // make sure that there arent a ton of elements in the input tensor + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { + test_reference_choose_qparams_tensor( + {2, 3, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +void test_reference_choose_qparams_per_token_asymmetric( + const std::vector<int>& input_sizes, + at::ScalarType dtype) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_per_token_asymmetric_reference_impl(input, dtype); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector<int>& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Calculate output sizes + std::vector<int64_t> output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Build Vulkan choose_qparams_per_token_asymmetric graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Output tensors + const ValueRef r_scale = + graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = + graph.add_tensor(output_sizes, 
vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_per_token_asymmetric + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) + .contiguous(); + at::Tensor vk_zero_point = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) + .contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST( + VulkanChooseQparamsTest, + test_reference_choose_qparams_per_token_asymmetric_int8) { + test_reference_choose_qparams_per_token_asymmetric( + {2, 3, 4}, // input sizes (2*3=6 tokens) + at::kChar); +} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp new file mode 100644 index 00000000000..7b155c8f98b --- /dev/null +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -0,0 +1,1061 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include <gtest/gtest.h> + +#include <ATen/ATen.h> + +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> + +#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> +#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> + +#include "test_utils.h" + +#include <cassert> +#include <iostream> +#include <limits> + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional<ScalarType> out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +// Wrapper function for dequantize_per_tensor_out without context +Tensor& dequantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional<ScalarType> out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +// Wrapper function for dequantize_per_token_out without context +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +// ATen wrapper for dequantize_per_tensor +at::Tensor dequantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + +// ATen wrapper for dequantize_per_token +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, + scale, + zero_points, + quant_min, + quant_max, + et_dtype, + et_out_dtype, + out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_dequantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType in_dtype, + c10::ScalarType out_dtype) { 
+ using namespace vkcompute; + + // Check that quant_min <= quant_max + VK_CHECK_COND( + quant_min <= quant_max, + "quant_min must be <= quant_max, got quant_min: ", + quant_min, + " quant_max: ", + quant_max); + + // Check that input dtype is a quantized type + switch (in_dtype) { + case c10::kByte: + case c10::kChar: + case c10::kShort: + case c10::kInt: + case c10::kLong: + break; + default: + VK_THROW( + "Unsupported input dtype: ", + scalar_type_name(in_dtype), + " (", + static_cast<int>(in_dtype), + ")"); + } + + // Check that output dtype is a floating point type + switch (out_dtype) { + case c10::kHalf: + case c10::kFloat: + case c10::kDouble: + break; + default: + VK_THROW( + "Unsupported output dtype: ", + scalar_type_name(out_dtype), + " (", + static_cast<int>(out_dtype), + ")"); + } +} + +// +// Reference Implementation +// + +/* + * Reference implementation of dequantize_per_tensor + */ +at::Tensor dequantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Dequantize the input tensor + at::Tensor flat_input = input.flatten(); + at::Tensor flat_out = out.flatten(); + + // Store casted values to avoid repeated casting + const int32_t zero_point_int32 = static_cast<int32_t>(zero_point); + const float scale_float = static_cast<float>(scale); + + for (int i = 0; i < flat_input.numel(); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = flat_input[i].item<uint8_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = flat_input[i].item<int8_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = flat_input[i].item<int16_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = flat_input[i].item<int32_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = flat_input[i].item<int64_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + flat_out[i] = static_cast<float>(dequantized_value); + } else if (out_dtype == at::kDouble) { + flat_out[i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + flat_out[i] = static_cast<c10::Half>(dequantized_value); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of dequantize_per_token + */ +at::Tensor dequantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == 
zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Dequantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Get scale and zero_point for this token + float token_scale = scale[token_idx].item<float>(); + int64_t token_zero_point = zero_point[token_idx].item<int64_t>(); + + // Store casted values to avoid repeated casting + const int32_t token_zero_point_int32 = + static_cast<int32_t>(token_zero_point); + + // Dequantize the token + for (int i = 0; i < input.size(-1); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = reshaped_input[token_idx][i].item<uint8_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kChar) { + int8_t qvalue = reshaped_input[token_idx][i].item<int8_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kShort) { + int16_t qvalue = reshaped_input[token_idx][i].item<int16_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kInt) { + int32_t qvalue = reshaped_input[token_idx][i].item<int32_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kLong) { + int64_t qvalue = reshaped_input[token_idx][i].item<int64_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + reshaped_out[token_idx][i] = static_cast<float>(dequantized_value); + } else if (out_dtype == at::kDouble) { + reshaped_out[token_idx][i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + reshaped_out[token_idx][i] = static_cast<c10::Half>(dequantized_value); + } + } + } + + return out; +} + +// Forward declaration of implementation functions +void test_vulkan_dequantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_dequantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + 
vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_dequantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast<int64_t>(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = dequantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + 
std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast<int64_t>(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = + torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Build Vulkan dequantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + + const ValueRef r_scale = graph.add_scalar<double>(scale); + const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point); + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan dequantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, 
vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, vk_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Test cases for dequantize_per_tensor +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_float) { + test_reference_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int8_to_float) { + test_reference_dequantize_per_tensor( + {3, 4, 5}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_float) { + test_reference_dequantize_per_tensor( + {4, 6, 2}, // input sizes + 0.2, // scale + 2, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_half) { + test_reference_dequantize_per_tensor( + {7, 4}, // input sizes + 0.1, // scale + 10, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype (uint8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_half) { + test_reference_dequantize_per_tensor( + {2, 6, 5}, // input sizes + 0.3, // scale + -10, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +void test_reference_dequantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if 
(dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast<int64_t>(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = dequantize_per_token_reference_impl( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + 
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast<int64_t>(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Build Vulkan dequantize_per_token graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + 
zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, vk_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Test cases for dequantize_per_token +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_uint8_to_float) { + std::vector<float> scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector<int> zero_points = {5, 10, 15, 20, 25, 30}; + + test_reference_dequantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_float) { + std::vector<float> scales = {0.05, 0.1, 0.15, 0.2}; + std::vector<int> zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 5}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_float) { + std::vector<float> scales = {0.05, 0.1, 0.15, 0.2}; + std::vector<int> zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 10}, // input sizes (2*2=4 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_half) { + std::vector<float> scales = {0.05, 0.1, 0.15, 0.2}; + std::vector<int> zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {4, 1, 5}, // input sizes (4*1=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype (int8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_half) { + std::vector<float> scales = {0.05, 0.1}; + std::vector<int> zero_points = {0, -5}; + + test_reference_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} diff --git 
a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index b95b7b3aa6d..e48042c4620 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -14,6 +14,8 @@ #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> +#include "test_utils.h" + #include <cassert> // @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w( ASSERT_TRUE(at::allclose(out, out_ref)); } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_linear_qga4w_impl( const int B, const int M, diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp new file mode 100644 index 00000000000..8b79dc1ce6b --- /dev/null +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -0,0 +1,843 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <gtest/gtest.h> + +#include <ATen/ATen.h> + +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> + +#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> +#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> + +#include "test_utils.h" + +#include <cassert> +#include <iostream> + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& quantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +// Wrapper function for quantize_per_tensor_out without context +Tensor& quantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// Wrapper function for quantize_per_token_out without context +Tensor& quantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// ATen wrapper for quantize_per_tensor +at::Tensor quantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = 
at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +// ATen wrapper for quantize_per_token +at::Tensor quantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_quantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType out_dtype) { + using namespace vkcompute; + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + switch (out_dtype) { + case c10::kByte: + quant_min_lower_bound = + static_cast<int32_t>(std::numeric_limits<uint8_t>::min()); + quant_max_upper_bound = + static_cast<int32_t>(std::numeric_limits<uint8_t>::max()); + break; + case c10::kChar: + quant_min_lower_bound = + static_cast<int32_t>(std::numeric_limits<int8_t>::min()); + quant_max_upper_bound = + static_cast<int32_t>(std::numeric_limits<int8_t>::max()); + break; + case c10::kBits16: + case c10::kUInt16: + quant_min_lower_bound = std::numeric_limits<uint16_t>::min(); + quant_max_upper_bound = std::numeric_limits<uint16_t>::max(); + break; + case c10::kShort: + quant_min_lower_bound = std::numeric_limits<int16_t>::min(); + quant_max_upper_bound = std::numeric_limits<int16_t>::max(); + break; + case c10::kInt: + quant_min_lower_bound = std::numeric_limits<int32_t>::min(); + quant_max_upper_bound = std::numeric_limits<int32_t>::max(); + break; + default: + VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); + } + VK_CHECK_COND( + quant_min >= quant_min_lower_bound, + "quant_min out of bound for dtype, expected quant_min_lower_bound: ", + quant_min_lower_bound, + " actual quant_min: ", + quant_min); + + VK_CHECK_COND( + quant_max <= quant_max_upper_bound, + "quant_max out of bound for dtype, expected quant_max_upper_bound: ", + quant_max_upper_bound, + " actual quant_max: ", + quant_max); +} + +// +// Reference Implementation +// + +/* + * Reference implementation of quantize_per_tensor + */ +at::Tensor quantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Quantize the input tensor + float inv_scale = 1.0 / scale; + + // Iterate through the tensor and quantize each element + at::Tensor float_input = input.to(at::kFloat); + at::Tensor float_values = float_input.flatten(); + + auto out_flat = out.flatten(); + + for (int i = 0; i < float_values.numel(); i++) { + float value = float_values[i].item<float>(); + int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max<int64_t>(qvalue, quant_min); + qvalue = std::min<int64_t>(qvalue, quant_max); + + if (dtype == at::kByte) { + out_flat[i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + out_flat[i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + out_flat[i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + 
out_flat[i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + out_flat[i] = static_cast<int64_t>(qvalue); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of quantize_per_token + */ +at::Tensor quantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Quantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Use float for scale since Vulkan doesn't support double + float token_scale = scale[token_idx].item<float>(); + // Use int for zero_point since Vulkan doesn't support int64_t + int token_zero_point = zero_point[token_idx].item<int>(); + + float inv_scale = 1.0 / token_scale; + + // Quantize the token + for (int i = 0; i < input.size(-1); i++) { + float value = reshaped_input[token_idx][i].item<float>(); + // Accumulate and clamp in int64_t (as in the per-tensor reference) so + // the value is not narrowed before it is clamped + int64_t qvalue = token_zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max<int64_t>(qvalue, quant_min); + qvalue = std::min<int64_t>(qvalue, quant_max); + + if (dtype == at::kByte) { + reshaped_out[token_idx][i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + reshaped_out[token_idx][i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + reshaped_out[token_idx][i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + reshaped_out[token_idx][i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + reshaped_out[token_idx][i] = static_cast<int64_t>(qvalue); + } + } + } + + return out; +} + +// Forward declarations of implementation functions +void test_vulkan_quantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_quantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, +
scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_quantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f; + if (input.numel() > 1) { + step = 1.0f / (input.numel() - 1); + } + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = quantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - impl_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + //
Get reference output + at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Build Vulkan quantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_scale = graph.add_scalar<double>(scale); + const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point); + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan quantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + // For quantized types, we need to compare the actual integer values + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.04, // scale + 5, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_uint8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.2, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kHalf, + at::kInt); +} + +void 
test_reference_quantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f; + if (input.numel() > 1) { + step = 1.0f / (input.numel() - 1); + } + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scales and zero_points + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = quantize_per_token_reference_impl( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Convert to int for consistent comparison and display regardless of + // underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_int); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with random values + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(),
input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output to show what we would compare against + at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int8) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int32) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_int32) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kHalf, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_uint8) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index 534bb577e7a..eebbb89ab40 100644 --- a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -14,6 +14,8 @@ #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> +#include "test_utils.h" + #include <cassert> // @@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl( // Test functions // -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_reference( const int n_heads = 4, const int n_kv_heads = 2, diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index 772039eda6a..79b679674a5 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -18,6 +18,8 @@ #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> #include <executorch/extension/llm/custom_ops/op_sdpa.h> +#include "test_utils.h" + #include <cassert> #include 
<iostream> @@ -261,24 +263,6 @@ void test_reference_sdpa( } } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_sdpa( const int start_input_pos, const int base_sequence_len, diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 5c9afa40762..0d014c7ef29 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), ) + runtime.cxx_library( + name = "test_utils", + srcs = [ + "test_utils.cpp", + ], + headers = [ + "test_utils.h", + ], + exported_headers = [ + "test_utils.h", + ], + deps = [ + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/runtime/core/exec_aten:lib", + runtime.external_dep_location("libtorch"), + ], + visibility = [ + "//executorch/backends/vulkan/test/op_tests/...", + "@EXECUTORCH_CLIENTS", + ], + ) + define_test_targets( "compute_graph_op_tests", src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" @@ -150,9 +172,47 @@ def define_common_targets(is_fbcode = False): define_test_targets( "sdpa_test", extra_deps = [ + ":test_utils", "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", ] ) - define_test_targets("linear_weight_int4_test") - define_test_targets("rotary_embedding_test") + define_test_targets( + "quantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_quantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "dequantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "choose_qparams_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_choose_qparams", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "linear_weight_int4_test", + extra_deps = [ + ":test_utils", + ] + ) + define_test_targets( + "rotary_embedding_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp new file mode 100644 index 00000000000..196f079be2c --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "test_utils.h" + +#include <stdexcept> + +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype) { + using ScalarType = executorch::aten::ScalarType; + switch (dtype) { + case at::kByte: + return ScalarType::Byte; + case at::kChar: + return ScalarType::Char; + case at::kShort: + return ScalarType::Short; + case at::kInt: + return ScalarType::Int; + case at::kLong: + return ScalarType::Long; + case at::kHalf: + return ScalarType::Half; + case at::kFloat: + return ScalarType::Float; + case at::kDouble: + return ScalarType::Double; + default: + throw std::runtime_error("Unsupported dtype"); + } +} + +std::string scalar_type_name(c10::ScalarType dtype) { + switch (dtype) { + case c10::kLong: + return "c10::kLong"; + case c10::kShort: + return "c10::kShort"; + case c10::kComplexHalf: + return "c10::kComplexHalf"; + case c10::kComplexFloat: + return "c10::kComplexFloat"; + case c10::kComplexDouble: + return "c10::kComplexDouble"; + case c10::kBool: + return "c10::kBool"; + case c10::kQInt8: + return "c10::kQInt8"; + case c10::kQUInt8: + return "c10::kQUInt8"; + case c10::kQInt32: + return "c10::kQInt32"; + case c10::kBFloat16: + return "c10::kBFloat16"; + case c10::kQUInt4x2: + return "c10::kQUInt4x2"; + case c10::kQUInt2x4: + return "c10::kQUInt2x4"; + case c10::kFloat: + return "c10::kFloat"; + case c10::kHalf: + return "c10::kHalf"; + case c10::kInt: + return "c10::kInt"; + case c10::kChar: + return "c10::kChar"; + case c10::kByte: + return "c10::kByte"; + case c10::kDouble: + return "c10::kDouble"; + case c10::kUInt16: + return "c10::kUInt16"; + case c10::kBits16: + return "c10::kBits16"; + default: + return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")"; + } +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kHalf: + return vkapi::kHalf; + case c10::kFloat: + return vkapi::kFloat; + case c10::kDouble: + return vkapi::kDouble; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kLong; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + case c10::kShort: + return vkapi::kShort; + case c10::kUInt16: + return vkapi::kUInt16; + default: + VK_THROW( + "Unsupported at::ScalarType: ", + scalar_type_name(at_scalartype), + " (", + static_cast<int>(at_scalartype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/test_utils.h b/backends/vulkan/test/op_tests/test_utils.h new file mode 100644 index 00000000000..369767007e0 --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include <string> + +#include <ATen/ATen.h> +#include <c10/core/ScalarType.h> +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/runtime/core/exec_aten/exec_aten.h> + +/** + * Convert at::ScalarType to executorch::ScalarType + */ +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype); + +/** + * Get the string name of a c10::ScalarType for better error messages + */ +std::string scalar_type_name(c10::ScalarType dtype); + +/** + * Convert c10::ScalarType to vkcompute::vkapi::ScalarType + */ +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype); diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 65bb959f6d1..a054fdf1a19 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -177,6 +177,8 @@ def generate_benchmark_fixture(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{ switch (at_scalartype) {{ + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -187,6 +189,8 @@ def generate_benchmark_fixture(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 4f0d2ff11ef..e7cf5ba92a5 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -110,6 +110,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: diff --git a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml index a00bba2bc5a..69587bd38d0 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml +++ b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml @@ -6,7 +6,7 @@ warp_size: parameter_names_with_default_values: - DTYPE: int + DTYPE: int32 STORAGE: buffer generate_variant_forall: METHOD: diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index c1f2770d3d6..876099598dc 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out( static_cast<float>(scale)); \ } \ } break; -#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast<int8_t>(out.scalar_type())); \ - } \ +#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast<int8_t>(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { @@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out( } \ out_data_ptr[current_ix] = \ static_cast<CTYPE_OUT>( \ - input_data_ptr[current_ix] - zero_point) * \ + input_data_ptr[current_ix] 
- \ + static_cast<int32_t>(zero_point)) * \ _scale; \ } \ }, \ @@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out( apply_over_dim_list( \ [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ + (input_data_ptr[in_ix] - static_cast<int32_t>(_zero_point)) * \ + _scale); \ }, \ input, \ optional_dim_list, \ channel_ix); \ } \ break; -#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast<int8_t>(out.scalar_type())); \ - } \ +#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast<int8_t>(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 4665c3d665b..d0b7c882f8e 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, @@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index bbda1590a10..4a0c195e3ab 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) { test_dtype<ScalarType::Int>(); } +/// Test all supported output dtypes for dequantization +template <ScalarType OUT_DTYPE> +void test_output_dtype() { + TensorFactory<ScalarType::Byte> tf; + + Tensor input = tf.full({3, 5}, 100); + double scale = 0.5; + int64_t zero_point = 30; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<OUT_DTYPE> tfo; + Tensor out = tfo.zeros({3, 5}); + // (100 - 30) * 0.5 = 35 + Tensor expected = tfo.full({3, 5}, 35); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional<ScalarType>(OUT_DTYPE), + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, AllOutputDtypesSupported) { + et_pal_init(); + test_output_dtype<ScalarType::Float>(); + test_output_dtype<ScalarType::Double>(); + test_output_dtype<ScalarType::Half>(); +} + +TEST(OpDequantizeOutTest, HalfOutput) { + et_pal_init(); + TensorFactory<ScalarType::Byte> tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<ScalarType::Half> tfo; + Tensor out = tfo.zeros({3, 5}); + // (10 - 100000) * 0.5 = -49995 + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional<ScalarType>(ScalarType::Half), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); 
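+ // Note: -49995 needs more precision than Half's 11-bit significand offers, + // so both tfo.full and the kernel store the nearest representable half + // value; both sides round identically, which is why exact equality holds.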
+ EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, DoubleOutput) { + et_pal_init(); + TensorFactory<ScalarType::Byte> tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<ScalarType::Double> tfo; + Tensor out = tfo.zeros({3, 5}); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional<ScalarType>(ScalarType::Double), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpDequantizeOutTest, NonWholeNumbers) { et_pal_init(); TensorFactory<ScalarType::Byte> tf; diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 704d8d06c5c..5cd17223d80 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -49,6 +49,32 @@ void test_dtype() { EXPECT_TENSOR_EQ(out, expected); } +template <ScalarType INPUT_DTYPE> +void test_input_dtype() { + TensorFactory<INPUT_DTYPE> tf_input; + + Tensor input = tf_input.full({3, 5}, 4); + double scale = 0.5; + int64_t zero_point = 108; + int64_t quant_min = 0; + int64_t quant_max = 127; + + TensorFactory<ScalarType::Char> tfo; + Tensor out = tfo.zeros({3, 5}); + // 4 / 0.5 + 108 = 116 + Tensor expected = tfo.full({3, 5}, 116); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, AllInputDtypesSupported) { + test_input_dtype<ScalarType::Float>(); + test_input_dtype<ScalarType::Half>(); + test_input_dtype<ScalarType::Double>(); +} + TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype<ScalarType::Byte>(); test_dtype<ScalarType::Char>(); @@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype<ScalarType::Int>(); } +TEST(OpQuantizeOutTest, DoubleInputTest) { + TensorFactory<ScalarType::Double> tf_double; + + // Test with a more complex value that might have precision differences + Tensor input = tf_double.full({2, 3}, 3.14159265359); + double scale = 0.01; + int64_t zero_point = -100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<ScalarType::Byte> tfo; + Tensor out = tfo.zeros({2, 3}); + // 3.14159265359 / 0.01 - 100 = 214.159265359 + Tensor expected = tfo.full({2, 3}, 214); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, HalfInputTest) { + TensorFactory<ScalarType::Half> tf_half; + + Tensor input = tf_half.full({2, 3}, 2.5); + double scale = 0.5; + int64_t zero_point = 10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory<ScalarType::Char> tfo; + Tensor out = tfo.zeros({2, 3}); + // 2.5 / 0.5 + 10 = 15 + Tensor expected = tfo.full({2, 3}, 15); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpQuantizeOutTest, TensorArgOverload) { TensorFactory<ScalarType::Float> tf_float; TensorFactory<ScalarType::Double> tf_double; diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 6f81146e925..d81b3ad4d0f 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ 
-199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT, float, Float) \ _(ANOTHER_INPUT, double, Double) +#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) \ + _(ANOTHER_INPUT, ::executorch::aten::Half, Half) + #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)
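For context on this last hunk: ET_FORALL_FLOATH_TYPES_WITH extends the existing float/double X-macro with Half, which is what lets the quantize/dequantize kernels above dispatch over half-precision dtypes. Below is a minimal, self-contained sketch of how such an X-macro expands into switch cases. It is illustrative only: the ScalarType enum, the Half stand-in, and dequantize_to are invented for the sketch and are not ExecuTorch APIs.

#include <cstdio>

// Stand-ins for the real types, for illustration only.
enum class ScalarType { Float, Double, Half };
struct Half { unsigned short bits; };  // 2-byte stand-in for aten::Half

// Mirrors the shape of ET_FORALL_FLOATH_TYPES_WITH: invoke the callback
// once per (extra input, C type, ScalarType name) triple.
#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, float, Float)                      \
  _(ANOTHER_INPUT, double, Double)                    \
  _(ANOTHER_INPUT, Half, Half)

// Hypothetical kernel: dequantize IN_CTYPE input into OUT_CTYPE output.
template <typename IN_CTYPE, typename OUT_CTYPE>
void dequantize_to() {
  std::printf(
      "dequantize %zu-byte input -> %zu-byte output\n",
      sizeof(IN_CTYPE),
      sizeof(OUT_CTYPE));
}

void dispatch_out_dtype(ScalarType out_dtype) {
// Each expansion becomes one case, so adding Half to the macro
// automatically adds a Half case at every switch that uses it.
#define DEQUANTIZE_CASE(IN_CTYPE, OUT_CTYPE, name) \
  case ScalarType::name:                           \
    dequantize_to<IN_CTYPE, OUT_CTYPE>();          \
    break;

  switch (out_dtype) {
    ET_FORALL_FLOATH_TYPES_WITH(signed char, DEQUANTIZE_CASE)
  }
#undef DEQUANTIZE_CASE
}

int main() {
  dispatch_out_dtype(ScalarType::Half);  // prints: dequantize 1-byte input -> 2-byte output
  return 0;
}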