@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"),  # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
             # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
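For context, the gate removed above skipped any `torch.nn.Linear` whose `in_features` or `out_features` reached 16384, mirroring `tensor_dim_limit` in the Vulkan partitioner; after this change every eligible `Linear` is considered, and any size gating presumably lives solely on the partitioner side. A minimal, self-contained sketch of the before/after predicate (the helper names and the 4096x128256 head are hypothetical, chosen only to illustrate which layers the old limit excluded):

import torch

def eligible_before(mod: torch.nn.Module, feature_limit: int = 16384) -> bool:
    # Old gate: only Linear layers strictly below the limit in both
    # dimensions were considered for int4 weight-only replacement.
    return isinstance(mod, torch.nn.Linear) and (
        mod.out_features < feature_limit and mod.in_features < feature_limit
    )

def eligible_after(mod: torch.nn.Module) -> bool:
    # New gate: any Linear is considered, regardless of size.
    return isinstance(mod, torch.nn.Linear)

head = torch.nn.Linear(4096, 128256)  # e.g. a large vocabulary output head
assert not eligible_before(head)  # previously skipped by the limit
assert eligible_after(head)       # now picked up for quantization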