diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 8409e1a9f38..bc1c2ef3d66 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -39,6 +39,7 @@
 from .insert_table_ops import InsertTableOpsPass  # noqa
 from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
+from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index f56be30083c..c085e3def1b 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -40,6 +40,7 @@
     InsertTableOpsPass,
     KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
+    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceScalarWithTensorArgPassTOSABI,
@@ -80,6 +81,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())
 
@@ -130,6 +132,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
 
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py
index 759f215a034..2cfc9b2b86a 100644
--- a/backends/arm/_passes/match_arg_ranks_pass.py
+++ b/backends/arm/_passes/match_arg_ranks_pass.py
@@ -49,6 +49,7 @@ def __init__(self, exported_program):
         exir_ops.edge.aten.bitwise_left_shift.Tensor,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.pow.Tensor_Tensor,
+        exir_ops.edge.aten.where.self,
     ]
 
     def _match_op_rank(self, graph_module, node, arg, max_rank):
diff --git a/backends/arm/_passes/match_where_self_arg_dtype_pass.py b/backends/arm/_passes/match_where_self_arg_dtype_pass.py
new file mode 100644
index 00000000000..154602129f8
--- /dev/null
+++ b/backends/arm/_passes/match_where_self_arg_dtype_pass.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+DTYPE_RANK = {
+    torch.bool: 0,
+    torch.uint8: 1,
+    torch.int8: 2,
+    torch.int16: 3,
+    torch.int32: 4,
+    torch.int64: 5,
+    torch.float16: 6,
+    torch.float32: 7,
+    torch.float64: 8,
+}
+
+
+def get_largest_dtype(dtype_1, dtype_2):
+    """Find the largest dtype."""
+    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2
+
+
+class MatchWhereSelfDtypePass(ExportPass):
+    """Pass to match data types of non-condition input tensors.
+
+    Edge dialect allows different data types for non-condition tensors, while TOSA
+    does not. In cases where they differ a TOSA CAST operator is inserted.
+
+    There is an edge case where one input is `boolean`, which cannot be directly cast
+    to, for example, float32. When this occurs two CAST operators are added to first
+    cast to int8 and then to the correct target data type.
+
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.where.self
+        )
+        for node in node_list:
+            cond, input_, other_ = node.args
+
+            input_dtype = input_.meta["val"].dtype
+            other_dtype = other_.meta["val"].dtype
+            target_dtype = torch.float32
+            if input_dtype != other_dtype:
+                target_dtype = get_largest_dtype(input_dtype, other_dtype)
+
+            for arg in node.args[1:]:
+                arg_dtype = arg.meta["val"].dtype
+
+                if arg_dtype != target_dtype:
+                    if arg_dtype == torch.bool:
+                        # Bool is an edge case which cannot necessarily be directly
+                        # converted to the target data type.
+                        with graph.inserting_after(arg):
+                            replace_node_int8 = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node_int8.args = (arg,)
+                            replace_node_int8.kwargs = {"dtype": torch.int8}
+
+                        with graph.inserting_after(replace_node_int8):
+                            replace_node_fp32 = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node_fp32.args = (replace_node_int8,)
+                            replace_node_fp32.kwargs = {"dtype": target_dtype}
+                            node.replace_input_with(arg, replace_node_fp32)
+                    else:
+                        with graph.inserting_after(arg):
+                            replace_node = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node.args = (arg,)
+                            replace_node.kwargs = {"dtype": target_dtype}
+                            node.replace_input_with(arg, replace_node)
+
+                    modified_graph = True
+
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py
index 64f3fb3f816..69fda636423 100644
--- a/backends/arm/operator_support/ethos_u55_support.py
+++ b/backends/arm/operator_support/ethos_u55_support.py
@@ -149,6 +149,7 @@ class EthosU55NotSupported(OperatorSupportBase):
         exir_ops.edge.aten.reflection_pad1d.default,  # REVERSE
         exir_ops.edge.aten.reflection_pad2d.default,  # REVERSE
         exir_ops.edge.aten.reflection_pad3d.default,  # REVERSE
+        exir_ops.edge.aten.where.self,  # SELECT
     ]
 
     def __init__(self, reporter: WhyNoPartitionReporter):
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 2a31ecbc775..0e5d7ecc958 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -207,6 +207,7 @@ def is_node_supported(
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index f891d8d3b69..2a610536f3e 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -49,6 +49,7 @@
     op_transpose,
     op_upsample_nearest2d,
     op_view,
+    op_where,
     ops_binary,
     ops_unary,
 )
diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py
new file mode 100644
index 00000000000..c8b35e831d4
--- /dev/null
+++ b/backends/arm/operators/op_where.py
@@ -0,0 +1,103 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Sequence
+
+import serializer.tosa_serializer as ts  # type: ignore
+
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+def _add_node_to_tosa_graph(
+    tosa_graph: ts.TosaSerializer,
+    inputs: List[TosaArg],
+    output: TosaArg,
+    supported_dtypes: Sequence,
+) -> None:
+    if len(inputs) != 3:
+        raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
+
+    if inputs[0].dtype is not ts.DType.BOOL:
+        raise ValueError("Input 0 needs to have dtype BOOL")
+    if inputs[1].dtype != inputs[2].dtype:
+        raise ValueError(
+            "Non-condition tensors must have same data type, got "
+            f"{inputs[1].dtype} and {inputs[2].dtype}"
+        )
+    for input_ in inputs[1:]:
+        if input_.dtype not in supported_dtypes:
+            raise ValueError(
+                f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
+            )
+
+    tosa_graph.addOperator(
+        TosaOp.Op().SELECT,
+        [inputs[0].name, inputs[1].name, inputs[2].name],
+        [output.name],
+        None,
+    )
+
+
+@register_node_visitor
+class WhereVisitor_080_BI(NodeVisitor):
+    target = "aten.where.self"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        bi_supported_dtypes = [
+            ts.DType.INT8,
+            ts.DType.INT16,
+            ts.DType.INT32,
+            ts.DType.BOOL,
+        ]
+        _add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
+
+
+@register_node_visitor
+class WhereVisitor_080_MI(WhereVisitor_080_BI):
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        mi_supported_dtypes = [
+            ts.DType.FP16,
+            ts.DType.FP32,
+            ts.DType.INT8,
+            ts.DType.INT16,
+            ts.DType.INT32,
+            ts.DType.BOOL,
+        ]
+        _add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index e9ed6be81f3..baca13029a3 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -238,13 +238,14 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.clamp.default,
     torch.ops.aten.clamp.Tensor,
+    torch.ops.aten.where,
     operator.getitem,
 ]
 
 
 def get_quant_properties(  # noqa: C901
     node: Node, gm: torch.fx.GraphModule, quantization_config
-) -> _OpQuantProperties:
+) -> _OpQuantProperties | None:
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()
@@ -322,6 +323,13 @@ def any_or_hardtanh_min_zero(n: Node):
             ),
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+    elif node.target in (torch.ops.aten.where.self,):
+        shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
+        quant_properties.quant_inputs = [
+            _QuantProperty(1, shared_qspec),  # type: ignore[arg-type]
+            _QuantProperty(2, shared_qspec),  # type: ignore[arg-type]
+        ]
+        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
@@ -376,16 +384,16 @@ def any_or_hardtanh_min_zero(n: Node):
         quant_properties.quant_output = None
     elif node.target in _parent_shared_qspec:
         if not isinstance(node.args[0], Node):
-            return None  # type: ignore[return-value]
+            return None
 
         if not arm_quantizer_utils.is_output_annotated(node.args[0]):  # type: ignore[attr-defined]
-            return None  # type: ignore[return-value]
+            return None
 
         shared_qspec = SharedQuantizationSpec(node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]  # type: ignore[arg-type]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     else:
-        return None  # type: ignore[return-value]
+        return None
 
     # Don't check if operator.getitem is ok for quantization, it's always ok
     if node.target == operator.getitem:
@@ -394,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
     # Check that each inputs/outputs can be quantized properly with the
     # provided quantization properties.
     if not _is_ok_for_quantization(node, quant_properties, gm):
-        return None  # type: ignore[return-value]
+        return None
 
     return quant_properties
 
diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py
index e270fb18205..dc5ecc7ca97 100644
--- a/backends/arm/test/models/test_conformer.py
+++ b/backends/arm/test/models/test_conformer.py
@@ -31,10 +31,8 @@ class TestConformer(unittest.TestCase):
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_where_self": 4,
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 4,
     }
 
     dim = 16
diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index a6da04b0e2e..bd18ff1856f 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -114,7 +114,7 @@ def test_llama_tosa_MI(self):
                 )
                 .export()
                 .to_edge_transform_and_lower()
-                .check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
                 .to_executorch()
                 .run_method_and_compare_outputs(
                     inputs=llama_inputs,
diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py
new file mode 100644
index 00000000000..bf127460f3e
--- /dev/null
+++ b/backends/arm/test/ops/test_where.py
@@ -0,0 +1,276 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+
+import pytest
+
+import torch
+
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    EthosUQuantizer,
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU85PipelineBI,
+    OpNotSupportedPipeline,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+
+aten_op = "torch.ops.aten.where.self"
+exir_op = "executorch_exir_dialects_edge__ops_aten_where_self"
+
+
+class Where(torch.nn.Module):
+    def __init__(
+        self, shape: tuple | int, dtype: torch.dtype | Tuple[torch.dtype], condition
+    ):
+        super().__init__()
+        self.shape = shape if isinstance(shape, tuple) else (shape,) * shape
+        self.dtype = (dtype, dtype) if isinstance(dtype, torch.dtype) else dtype
+        self.condition = condition
+
+    def get_inputs(self):
+        inputs: List = [0, 0]
+        for i in range(2):
+            if self.dtype[i] in [torch.int8, torch.int16, torch.int32]:
+                inputs[i] = torch.randint(
+                    torch.iinfo(self.dtype[i]).min,
+                    torch.iinfo(self.dtype[i]).max,
+                    self.shape,
+                    dtype=self.dtype[i],
+                )
+            elif self.dtype[i] in [torch.float32]:
+                inputs[i] = torch.randn(*self.shape).to(self.dtype[i])
+            elif self.dtype[i] is torch.bool:
+                inputs[i] = torch.randint(0, 1, self.shape, dtype=torch.bool)
+            else:
+                raise TypeError(
+                    f"Input generation for dtype {self.dtype[i]} not implemented in "
+                    "Where()"
+                )
+
+        return tuple(inputs)
+
+    def forward(
+        self,
+        input_: torch.Tensor,
+        other_: torch.Tensor,
+    ):
+        return torch.where(self.condition(input_), input_, other_)
+
+
+def tensor_condition(input: torch.Tensor):
+    return input > torch.zeros_like(input)
+
+
+def scalar_condition(input: torch.Tensor):
+    return input > 0
+
+
+two_dim_tensor_cond = Where(
+    2,
+    torch.float32,
+    tensor_condition,
+)
+
+three_dim_tensor_cond = Where(
+    3,
+    torch.float32,
+    tensor_condition,
+)
+
+float32_tensor_cond = Where(
+    1,
+    torch.float32,
+    tensor_condition,
+)
+
+float32_tensor_cond_tuple_dtype = Where(
+    1,
+    (torch.float32, torch.int8),
+    tensor_condition,
+)
+
+float32_tensor_cond_tuple_dtype_bool = Where(
+    1,
+    (torch.float32, torch.bool),
+    tensor_condition,
+)
+
+# Scalar tests
+two_dim_scalar_cond = Where(
+    2,
+    torch.float32,
+    scalar_condition,
+)
+
+three_dim_scalar_cond = Where(
+    3,
+    torch.float32,
+    scalar_condition,
+)
+
+float32_scalar_cond = Where(
+    1,
+    torch.float32,
+    scalar_condition,
+)
+
+test_modules_common = {
+    "two_dim_tensor_cond": two_dim_tensor_cond,
+    "three_dim_tensor_cond": three_dim_tensor_cond,
+    "float32_tensor_cond": float32_tensor_cond,
+    "two_dim_scalar_cond": two_dim_scalar_cond,
+    "three_dim_scalar_cond": three_dim_scalar_cond,
+    "float32_scalar_cond": float32_scalar_cond,
+}
+
+test_modules_MI = {
+    **test_modules_common,
+    "float32_tensor_cond_tuple_dtype": float32_tensor_cond_tuple_dtype,
+    "float32_tensor_cond_tuple_dtype_bool": float32_tensor_cond_tuple_dtype_bool,
+}
+
+test_modules_BI = {
+    **test_modules_common,
+}
+
+input_t = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_module", test_modules_MI)
+def test_where_tosa_MI(test_module):
+    pipeline = TosaPipelineMI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+def test_where_tosa_BI(test_module):
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+BI")
+    quantizer = TOSAQuantizer(compile_spec).set_io(get_symmetric_quantization_config())
+    pipeline = TosaPipelineBI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op
+    )
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+def test_where_u55_BI(test_module):
+    compile_spec = common.get_u55_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+
+    # If condition is tensor_condition then there will be one full_like op which will be
+    # delegated.
+    if test_module.condition == tensor_condition:
+        num_delegates = 1
+        num_exir = 0
+    else:
+        num_delegates = 0
+        num_exir = 0
+
+    pipeline = OpNotSupportedPipeline[input_t](
+        test_module,
+        test_module.get_inputs(),
+        "TOSA-0.80+BI+u55",
+        {
+            exir_op: 1,
+            "executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
+        },
+        num_delegates,
+    )
+
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+def test_where_u85_BI(test_module):
+    compile_spec = common.get_u85_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+    pipeline = EthosU85PipelineBI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False
+    )
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+@pytest.mark.skip(reason="The same as test_where_u55_BI")
+@common.XfailIfNoCorstone300
+def test_where_u55_BI_on_fvp(test_module):
+    compile_spec = common.get_u55_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+
+    # If condition is tensor_condition then there will be one full_like op which will be
+    # delegated.
+    if test_module.condition == tensor_condition:
+        num_delegates = 1
+        num_exir = 0
+    else:
+        num_delegates = 0
+        num_exir = 0
+
+    pipeline = OpNotSupportedPipeline[input_t](
+        test_module,
+        test_module.get_inputs(),
+        "TOSA-0.80+BI+u55",
+        {
+            exir_op: 1,
+            "executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
+        },
+        num_delegates,
+    )
+
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_module",
+    test_modules_BI,
+    xfails={
+        "two_dim_scalar_cond": "E [executorch:method.cpp:601] Missing operator: "
+        "[2] aten::gt.Scalar_out",
+        "three_dim_scalar_cond": "E [executorch:method.cpp:601] Missing operator: "
+        "[2] aten::gt.Scalar_out",
+        "float32_scalar_cond": "E [executorch:method.cpp:601] Missing operator: "
+        "[2] aten::gt.Scalar_out",
+    },
+)
+@common.XfailIfNoCorstone320
+def test_where_u85_BI_on_fvp(test_module):
+    compile_spec = common.get_u85_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+    pipeline = EthosU85PipelineBI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True
+    )
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()