diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 4bd0c527605..547dba4a7c5 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -16,7 +16,7 @@ from torch.fx import Node _ScalarType = Union[int, bool, float] -_Argument = Union[torch.fx.Node, int, bool, float, str] +_Argument = Union[Node, int, bool, float, str] class VkGraphBuilder: @@ -29,7 +29,7 @@ def __init__(self, program: ExportedProgram) -> None: self.output_ids = [] self.const_tensors = [] - # Mapping from torch.fx.Node to VkValue id + # Mapping from Node to VkValue id self.node_to_value_ids = {} @staticmethod @@ -39,18 +39,18 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - def is_constant(self, node: torch.fx.Node): + def is_constant(self, node: Node): return ( node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants ) - def is_get_attr_node(self, node: torch.fx.Node) -> bool: + def is_get_attr_node(self, node: Node) -> bool: """ Returns true if the given node is a get attr node for a tensor of the model """ - return isinstance(node, torch.fx.Node) and node.op == "get_attr" + return isinstance(node, Node) and node.op == "get_attr" - def is_param_node(self, node: torch.fx.Node) -> bool: + def is_param_node(self, node: Node) -> bool: """ Check if the given node is a parameter within the exported program """ @@ -61,7 +61,7 @@ def is_param_node(self, node: torch.fx.Node) -> bool: or self.is_constant(node) ) - def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]: + def get_constant(self, node: Node) -> Optional[torch.Tensor]: """ Returns the constant associated with the given node in the exported program. Returns None if the node is not a constant within the exported program @@ -79,7 +79,7 @@ def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]: return None - def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor: + def get_param_tensor(self, node: Node) -> torch.Tensor: tensor = None if node is None: raise RuntimeError("node is None") @@ -168,7 +168,7 @@ def create_string_value(self, string: str) -> int: return new_id def get_or_create_value_for(self, arg: _Argument): - if isinstance(arg, torch.fx.Node): + if isinstance(arg, Node): # If the value has already been created, return the existing id if arg in self.node_to_value_ids: return self.node_to_value_ids[arg]