Skip to content

Prototype reorg of the FX converters to clean dependency chain #1683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 11 additions & 99 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
trt_transposed_matmul,
)
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
import activation

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1004,9 +1005,7 @@ def acc_ops_relu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.RELU
return add_activation_layer(network, input_val, operation_type, target, name)
return activation.add_relu(network, target, kwargs, name)


@tensorrt_converter(acc_ops.leaky_relu)
Expand All @@ -1017,12 +1016,7 @@ def acc_ops_leaky_relu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
negative_slope = kwargs["negative_slope"]
operation_type = trt.ActivationType.LEAKY_RELU
return add_activation_layer(
network, input_val, operation_type, target, name, negative_slope
)
return activation.add_leaky_relu(network, target, kwargs, name)


@tensorrt_converter(acc_ops.elu)
Expand All @@ -1033,11 +1027,7 @@ def acc_ops_elu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
alpha = kwargs["alpha"]
operation_type = trt.ActivationType.ELU
return add_activation_layer(network, input_val, operation_type, target, name, alpha)

return activation.add_elu(network, target, kwargs, name)

@tensorrt_converter(acc_ops.selu)
def acc_ops_selu(
Expand All @@ -1047,9 +1037,7 @@ def acc_ops_selu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.SELU
return add_activation_layer(network, input_val, operation_type, target, name)
return activation.add_selu(network, target, kwargs, name)


@tensorrt_converter(acc_ops.softsign)
Expand All @@ -1060,10 +1048,7 @@ def acc_ops_softsign(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.SOFTSIGN
return add_activation_layer(network, input_val, operation_type, target, name)

return activation.add_softsign(network, target, kwargs, name)

@tensorrt_converter(acc_ops.sin)
def acc_ops_sin(
Expand Down Expand Up @@ -1138,10 +1123,7 @@ def acc_ops_tanh(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.TANH
return add_activation_layer(network, input_val, operation_type, target, name)

return activation.add_tanh(network, target, kwargs, name)

@tensorrt_converter(acc_ops.asin)
def acc_ops_asin(
Expand Down Expand Up @@ -3129,23 +3111,7 @@ def acc_ops_hard_sigmoid(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]

if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"Hard sigmoid received input {input_val} that is not part "
"of the TensorRT region!"
)

return add_activation_layer(
network,
input_val,
trt.ActivationType.HARD_SIGMOID,
target,
name,
alpha=1 / 6,
beta=0.5,
)
return activation.add_hard_sigmoid(network, target, kwargs, name)


@tensorrt_converter(acc_ops.sigmoid)
Expand All @@ -3156,17 +3122,7 @@ def acc_ops_sigmoid(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]

if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"Sigmoid received input {input_val} that is not part "
"of the TensorRT region!"
)

return add_activation_layer(
network, input_val, trt.ActivationType.SIGMOID, target, name
)
return activation.add_sigmoid(network, target, kwargs, name)


@tensorrt_converter(acc_ops.permute)
Expand Down Expand Up @@ -3367,34 +3323,7 @@ def acc_ops_gelu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
approximate = kwargs["approximate"]
if approximate != "none":
raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"GELU received input {input_val} that is not part "
"of the TensorRT region!"
)
if network.has_implicit_batch_dimension:
raise RuntimeError(
"GeLU converter currently doesn't support implicit batch dimension"
)

plugin_name = "CustomGeluPluginDynamic"
# type_id 0 for float32, 1 for float16
type_id = trt.PluginField(
"type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
)
field_collection = TRTPluginFieldCollection([type_id])
plugin_version = "1"

plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

layer = network.add_plugin_v2([input_val], plugin)
set_layer_name(layer, target, name)
return layer.get_output(0)

return activation.add_gelu(network, target, kwargs, name)

@tensorrt_converter(acc_ops.chunk)
def acc_ops_chunk(
Expand Down Expand Up @@ -3549,24 +3478,7 @@ def acc_ops_hardtanh(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]

if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"hardtanh received input {input_val} that is not part "
"of the TensorRT region!"
)

return add_activation_layer(
network,
input_val,
trt.ActivationType.CLIP,
target,
name,
alpha=kwargs["min_val"],
beta=kwargs["max_val"],
)

return activation.add_hardtanh(network, target, kwargs, name)

@tensorrt_converter(acc_ops.interpolate)
def acc_ops_interpolate(
Expand Down
Loading