Skip to content

Reorg for converters leaky_relu (FX Converter Refactor [6/N]) <Target: converter_reorg_proto> #1902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,17 +1026,9 @@ def acc_ops_leaky_relu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
negative_slope = kwargs["negative_slope"]
operation_type = trt.ActivationType.LEAKY_RELU
return activation.convert_activation(
network,
target,
SourceIR.ACC,
name,
operation_type,
input_val,
alpha=negative_slope,

return activation.leaky_relu(
network, target, SourceIR.ACC, name, kwargs["input"], kwargs["negative_slope"]
)


Expand Down
27 changes: 27 additions & 0 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,33 @@ def aten_ops_hardtanh(
)


@tensorrt_converter(torch.ops.aten.fmod.Tensor)
def aten_ops_fmod(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"input": args[0],
"other": args[1],
}
return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.leaky_relu.default)
def aten_ops_leaky_relu(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])


@tensorrt_converter(torch.ops.aten.linear)
def aten_ops_linear(
network: TRTNetwork,
Expand Down
27 changes: 27 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,30 @@ def tanh_fn(x):
input_val,
dyn_range_fn=tanh_dyn_range_fn,
)


def leaky_relu(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
alpha: Optional[Any],
):
operation_type = trt.ActivationType.LEAKY_RELU

def leaky_relu_dyn_range_fn(dyn_range):
return (max(0, dyn_range[0]) + alpha * min(0, dyn_range[0])), (
max(0, dyn_range[1]) + alpha * min(0, dyn_range[1])
)

return convert_activation(
network,
target,
source_ir,
name,
operation_type,
input_val,
alpha,
dyn_range_fn=leaky_relu_dyn_range_fn,
)
16 changes: 16 additions & 0 deletions py/torch_tensorrt/fx/converters/nn_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,19 @@ def tanh(network, submod, args, kwargs, layer_name):
name=layer_name,
input_val=kwargs["input"],
)


@tensorrt_converter(torch.nn.functional.leaky_relu)
@tensorrt_converter(torch.nn.modules.activation.LeakyReLU)
def leaky_relu(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
assert len(args) == 0

return activation.leaky_relu(
network=network,
target="torch.nn.functional.leaky_relu",
source_ir=SourceIR.NN,
name=layer_name,
input_val=kwargs["input"],
alpha=kwargs["negative_slope"],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestLeakyReLUConverter(DispatchTestCase):
def test_leaky_relu(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.leaky_relu(x, negative_slope=0.05)

inputs = [torch.randn(1, 10)]
self.run_test(
TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default}
)

def test_leaky_relu_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.leaky_relu(x, negative_slope=0.05)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
)

def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.leaky_relu(x, negative_slope=0.05)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
)


if __name__ == "__main__":
run_tests()