[FX] Sync to OSS #1118

Merged: 1 commit, Jun 15, 2022
13 changes: 12 additions & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -92,7 +92,18 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
         input_specs_val = (
             self.lower_setting.input_specs
             if self.lower_setting.input_specs
-            else InputTensorSpec.from_tensors(input)
+            else (
+                InputTensorSpec.from_tensors_with_dynamic_batch_size(
+                    input,
+                    (
+                        0,
+                        self.lower_setting.max_batch_size,
+                        self.lower_setting.max_batch_size,
+                    ),
+                )
+                if self.lower_setting.explicit_batch_dimension
+                else InputTensorSpec.from_tensors(input)
+            )
         )

         # Prepare algorithm selector and timing_cache for TRTInterpreter
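For reference, the branch introduced above selects between two `InputTensorSpec` constructors depending on whether the explicit batch dimension is enabled. A minimal sketch of both code paths, assuming `InputTensorSpec` is re-exported from `torch_tensorrt.fx` and that the tuple passed to `from_tensors_with_dynamic_batch_size` is a (min, opt, max) batch-size range, as the call above suggests:

```python
import torch
from torch_tensorrt.fx import InputTensorSpec  # assumed re-export of input_tensor_spec.InputTensorSpec

# Example tensors whose first dimension is the batch dimension.
sample_inputs = [torch.randn(8, 3, 224, 224)]
max_batch_size = 32

# Explicit-batch path: batch dim may range from 0 to max_batch_size,
# with max_batch_size used as the optimization-profile "opt" value.
dynamic_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    sample_inputs,
    (0, max_batch_size, max_batch_size),  # (min, opt, max) batch-size range
)

# Implicit-batch path: shapes are taken from the sample tensors as-is.
static_specs = InputTensorSpec.from_tensors(sample_inputs)
```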
7 changes: 7 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -63,6 +63,12 @@ class LowerSetting(LowerSettingBasic):
        save_timing_cache: Save updated timing cache data into timing cache file if the timing
            cache file is provided.
        cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
+       preset_lowerer (str): when specified, use a preset logic to build the
+           instance of Lowerer. Refer to
+           `caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on
+           how presets are applied. Refer to
+           `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
+           to add a preset.
    """

    input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -79,3 +85,4 @@ class LowerSetting(LowerSettingBasic):
    timing_cache_prefix: str = ""
    save_timing_cache: bool = False
    cuda_graph_batch_size: int = -1
+   preset_lowerer: str = ""
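A hedged sketch of how the new field could be set when constructing a `LowerSetting`. The preset name `"my_preset"` is hypothetical, and `max_batch_size` / `explicit_batch_dimension` are assumed to be inherited dataclass fields from `LowerSettingBasic` (they are read from the setting object in lower.py above):

```python
from torch_tensorrt.fx.lower_setting import LowerSetting

# "my_preset" is a hypothetical preset name; the default empty string
# keeps the existing, non-preset Lowerer construction path.
setting = LowerSetting(
    max_batch_size=32,              # assumed field on LowerSettingBasic
    explicit_batch_dimension=True,  # assumed field on LowerSettingBasic
    preset_lowerer="my_preset",
)
```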
4 changes: 3 additions & 1 deletion py/torch_tensorrt/fx/passes/lower_basic_pass.py
@@ -31,7 +31,9 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
                return True
        return False

-    const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
+    const_split_mod = split_const_subgraphs(
+        traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda"
+    )
     const_split_mod.run_folding()
     return const_split_mod
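The extra `device_for_folded_attrs="cuda"` argument tells `torch.fx.experimental.const_fold.split_const_subgraphs` to materialize the folded constant attributes on the GPU rather than the CPU default. A minimal standalone sketch (requires a CUDA device; the toy module is illustrative only, not part of this PR):

```python
import torch
from torch.fx.experimental.const_fold import split_const_subgraphs


class AddFoldedConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # (self.w + 1) does not depend on the input, so it can be const-folded.
        return x + (self.w + 1)


traced = torch.fx.symbolic_trace(AddFoldedConst())
folded = split_const_subgraphs(traced, device_for_folded_attrs="cuda")
folded.run_folding()  # folded constants are materialized on the CUDA device
```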
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
@@ -2576,5 +2576,6 @@ def test_all_acc_ops_registered(self):
                acc_ops.new_ones,
                acc_ops.einsum,
                acc_ops.as_strided,
+               acc_ops.var,
            },
        )
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -2864,3 +2864,18 @@ def as_strided(*, input, size, stride, storage_offset=0):
     return torch.as_strided(
         input=input, size=size, stride=stride, storage_offset=storage_offset
     )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.var))
+@register_acc_op_mapping(
+    op_and_target=("call_method", "var"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("unbiased", "unbiased"),
+        ("keepdim", "keepdim"),
+    ],
+)
+@register_acc_op
+def var(*, input, dim, unbiased, keepdim=False):
+    return torch.var(input=input, dim=dim, unbiased=unbiased, keepdim=keepdim)
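A hedged sketch of what the new mapping enables: after this change, both the functional form `torch.var(...)` and the method form `Tensor.var(...)` should be normalized to `acc_ops.var` when a module is traced with the acc tracer. The import path follows the package layout shown in this diff; the toy module is illustrative only:

```python
import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer


class VarModule(torch.nn.Module):
    def forward(self, x):
        a = torch.var(x, dim=1, unbiased=True, keepdim=True)  # call_function form
        b = x.var(dim=1, unbiased=False, keepdim=True)        # call_method form
        return a + b


traced = acc_tracer.trace(VarModule(), [torch.randn(2, 8)])
print(traced.graph)  # both var calls should now target acc_ops.var
```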