
Commit e6a1193

feat: Add example usage scripts for dynamo path
- Add sample scripts covering resnet18, transformers, and custom examples showcasing the `torch_tensorrt.dynamo.torch_compile` path, which can compile models with data-dependent control flow and other constructs that make other compilation methods more difficult
- Cover the different customizable features allowed in the new backend
- Make the scripts interactive Jupyter notebooks
1 parent a77017c commit e6a1193
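
As a rough, non-authoritative illustration of the data-dependent control flow point in the commit message, the sketch below (not part of this commit) defines a hypothetical toy model whose forward pass branches on a tensor value and compiles it through the `"tensorrt"` backend used in the notebooks below; `ConditionalModel` and its inputs are invented for illustration only.

```python
import torch
import torch_tensorrt  # assumed to register the "tensorrt" backend used in the notebooks below

# Hypothetical toy model: the branch depends on runtime tensor values,
# which breaks whole-graph tracing but is handled by the torch.compile path,
# with non-convertible segments falling back to eager Torch execution.
class ConditionalModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # data-dependent control flow
            return torch.relu(x)
        return torch.tanh(x)

model = ConditionalModel().eval().cuda()
optimized = torch.compile(model, backend="tensorrt")
optimized(torch.randn((4, 4)).cuda())  # first call triggers compilation
```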

File tree

3 files changed: +475 −0 lines changed

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff5a530b",
   "metadata": {},
   "source": [
    "# Overview\n",
    "This interactive notebook is intended as an overview of the process by which `torch_tensorrt.dynamo.torch_compile` works, and how it integrates with the new `torch.compile` API."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dae5728",
   "metadata": {},
   "source": [
    "## Imports and Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd29ec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_tensorrt.dynamo.torch_compile import create_backend\n",
    "from torch_tensorrt.fx.lower_setting import LowerPrecision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eafb701f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We begin by defining a model\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.relu = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self, x: torch.Tensor, y: torch.Tensor):\n",
    "        x_out = self.relu(x)\n",
    "        y_out = self.relu(y)\n",
    "        x_y_out = x_out + y_out\n",
    "        return torch.mean(x_y_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7debfc0",
   "metadata": {},
   "source": [
    "## Compilation with `torch.compile` Using Default Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad82f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define sample float inputs and initialize model\n",
    "sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]\n",
    "model = Model().eval().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a4d0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next, we compile the model using torch.compile\n",
    "# For the default settings, we can simply call torch.compile\n",
    "# with the backend \"tensorrt\", and run the model on an\n",
    "# input to trigger compilation, like so:\n",
    "optimized_model = torch.compile(model, backend=\"tensorrt\")\n",
    "optimized_model(*sample_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff31119d",
   "metadata": {},
   "source": [
    "## Compilation with `torch.compile` Using Custom Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be6692d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define sample half inputs and initialize model\n",
    "sample_inputs_half = [torch.rand((5, 7)).half().cuda(), torch.rand((5, 7)).half().cuda()]\n",
    "model_half = Model().eval().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67240828",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If we want to customize certain options in the backend,\n",
    "# but still use the torch.compile call directly, we can call the\n",
    "# convenience/helper function create_backend to create a custom backend\n",
    "# which has been pre-populated with certain keys\n",
    "custom_backend = create_backend(\n",
    "    lower_precision=LowerPrecision.FP16,\n",
    "    debug=True,\n",
    "    min_block_size=2,\n",
    "    torch_executed_ops={},\n",
    ")\n",
    "\n",
    "# Run the model on an input to trigger compilation, like so:\n",
    "optimized_model_custom = torch.compile(model_half, backend=custom_backend)\n",
    "optimized_model_custom(*sample_inputs_half)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
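
One step the overview notebook above stops short of is comparing the compiled outputs to eager execution. A minimal sanity-check sketch, not part of the commit and assuming the `model`, `optimized_model`, and `sample_inputs` objects defined in the cells above, could look like:

```python
import torch

# Compare eager and compiled outputs on the same FP32 inputs;
# a loose tolerance allows for TensorRT kernel/precision differences.
with torch.no_grad():
    eager_out = model(*sample_inputs)
    compiled_out = optimized_model(*sample_inputs)

assert torch.allclose(eager_out, compiled_out, atol=1e-3, rtol=1e-3)
print("Eager and compiled outputs match within tolerance")
```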
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c0b24bfc",
   "metadata": {},
   "source": [
    "# Overview\n",
    "This script is intended as a sample of the `torch_tensorrt.dynamo.torch_compile` workflow on a ResNet model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04ecf2a9",
   "metadata": {},
   "source": [
    "## Imports and Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fc05cb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_tensorrt.dynamo import torch_compile\n",
    "import torchvision.models as models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb2b9221",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model with half precision and sample inputs\n",
    "model = models.resnet18(pretrained=True).half().eval().to(\"cuda\")\n",
    "inputs = [torch.randn((1, 3, 224, 224)).to(\"cuda\").half()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee3ab312",
   "metadata": {},
   "source": [
    "## Optional Input Arguments to `torch_tensorrt.dynamo.torch_compile`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6864197b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enabled precisions for TensorRT optimization\n",
    "enabled_precisions = {torch.half}\n",
    "\n",
    "# Whether to print verbose logs\n",
    "debug = True\n",
    "\n",
    "# Workspace size for TensorRT\n",
    "workspace_size = 20 << 30\n",
    "\n",
    "# Minimum number of operators per TRT-Engine block\n",
    "# (Lower value allows more graph segmentation)\n",
    "min_block_size = 3\n",
    "\n",
    "# Operations to run in Torch, regardless of converter support\n",
    "torch_executed_ops = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7648ba15",
   "metadata": {},
   "source": [
    "## Compilation with `torch_tensorrt.dynamo.torch_compile`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1be9d0b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build and compile the model with torch.compile, using the tensorrt backend\n",
    "optimized_model = torch_compile(\n",
    "    model,\n",
    "    inputs,\n",
    "    enabled_precisions=enabled_precisions,\n",
    "    debug=debug,\n",
    "    workspace_size=workspace_size,\n",
    "    min_block_size=min_block_size,\n",
    "    torch_executed_ops=torch_executed_ops,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c42544f",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6acf9768",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Does not cause recompilation (same batch size as compilation input)\n",
    "new_inputs = [torch.randn((1, 3, 224, 224)).half().to(\"cuda\")]\n",
    "new_outputs = optimized_model(*new_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83185cf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Does cause recompilation (new batch size)\n",
    "new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to(\"cuda\")]\n",
    "new_batch_size_outputs = optimized_model(*new_batch_size_inputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
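
Since the ResNet notebook's goal is runtime optimization, a rough latency comparison is a natural follow-up. The sketch below is not part of the commit: `benchmark` is a hypothetical helper, and `model`, `optimized_model`, and `new_inputs` are assumed from the cells above.

```python
import time
import torch

def benchmark(fn, args, iters=100):
    # Warm up, then time with explicit CUDA synchronization
    with torch.no_grad():
        for _ in range(10):
            fn(*args)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print(f"Eager ResNet18:     {benchmark(model, new_inputs) * 1e3:.2f} ms/iter")
print(f"Optimized ResNet18: {benchmark(optimized_model, new_inputs) * 1e3:.2f} ms/iter")
```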
