Skip to content

Commit e9f1077

Browse files
committed
Added changes for inductor and cudagraphs
1 parent bf51948 commit e9f1077

File tree

2 files changed

+144
-9
lines changed

2 files changed

+144
-9
lines changed

_includes/quick-start-module.js

+19-9
Original file line numberDiff line numberDiff line change
@@ -220,15 +220,18 @@ function getInstallCommand(optionID) {
220220

221221
function getTorchCompileUsage(optionId) {
222222
backend = getIDFromBackend(optionId);
223-
importCmd = "<br>" + getImportCmd(optionId) + "<br>";
223+
importCmd = getImportCmd(optionId) + "<br>";
224224
finalCmd = "";
225-
tcUsage = "# Torch Compile usage: ";
225+
tcUsage = "# Torch Compile usage: " + "<br>";
226226
backendCmd = `torch.compile(model, backend="${backend}")`;
227227
libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;
228228

229229
if (opts.pm == "libtorch") {
230230
return libtorchCmd;
231231
}
232+
if (backend == "inductor" || backend == "cudagraphs") {
233+
return tcUsage + backendCmd;
234+
}
232235
if (backend == "openvino") {
233236
if (opts.pm == "source") {
234237
finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "<br>" ;
@@ -268,18 +271,25 @@ function addTorchCompileCommandNote(selectedOptionId) {
268271
if (!selectedOptionId) {
269272
return;
270273
}
271-
272-
$("#command").append(
273-
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
274-
);
275-
$("#command").append(
276-
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
277-
);
274+
if (selectedOptionId == "inductor" || selectedOptionId == "cgraphs") {
275+
$("#command").append(
276+
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
277+
);
278+
}
279+
else {
280+
$("#command").append(
281+
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
282+
);
283+
$("#command").append(
284+
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
285+
);
286+
}
278287
}
279288

280289
function selectedOption(option, selection, category) {
281290
$(option).removeClass("selected");
282291
$(selection).addClass("selected");
292+
const previousSelection = opts[category];
283293
opts[category] = selection.id;
284294
if (category === "pm") {
285295
var elements = document.getElementsByClassName("language")[0].children;

assets/quick-start-module.js

+125
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ var opts = {
2121
pm: 'pip',
2222
language: 'python',
2323
ptbuild: 'stable',
24+
'torch-compile': null
2425
};
2526

2627
var supportedCloudPlatforms = [
@@ -34,6 +35,7 @@ var package = $(".package > .option");
3435
var language = $(".language > .option");
3536
var cuda = $(".cuda > .option");
3637
var ptbuild = $(".ptbuild > .option");
38+
var torchCompile = $(".torch-compile > .option")
3739

3840
os.on("click", function() {
3941
selectedOption(os, this, "os");
@@ -50,6 +52,9 @@ cuda.on("click", function() {
5052
ptbuild.on("click", function() {
5153
selectedOption(ptbuild, this, "ptbuild")
5254
});
55+
torchCompile.on("click", function() {
56+
selectedOption(torchCompile, this, "torch-compile")
57+
});
5358

5459
// Pre-select user's operating system
5560
$(function() {
@@ -63,6 +68,119 @@ $(function() {
6368
}
6469
});
6570

71+
function getIDFromBackend(backend) {
72+
const idTobackendMap = {
73+
inductor: 'inductor',
74+
cgraphs : 'cudagraphs',
75+
onnxrt: 'onnxrt',
76+
openvino: 'openvino',
77+
tensorrt: 'tensorrt',
78+
tvm: 'tvm',
79+
};
80+
return idTobackendMap[backend];
81+
}
82+
83+
function getPmCmd(backend) {
84+
const pmCmd = {
85+
onnxrt: 'onnxruntime',
86+
tvm: 'apache-tvm',
87+
openvino: 'openvino',
88+
tensorrt: 'torch-tensorrt',
89+
};
90+
return pmCmd[backend];
91+
}
92+
93+
function getImportCmd(backend) {
94+
const importCmd = {
95+
onnxrt: 'import onnxruntime',
96+
tvm: 'import tvm',
97+
openvino: 'import openvino.torch',
98+
tensorrt: 'import torch_tensorrt'
99+
}
100+
return importCmd[backend];
101+
}
102+
103+
function getInstallCommand(optionID) {
104+
backend = getIDFromBackend(optionID);
105+
pmCmd = getPmCmd(optionID);
106+
finalCmd = "";
107+
if (opts.pm == "pip") {
108+
finalCmd = `pip3 install ${pmCmd}`;
109+
}
110+
else if (opts.pm == "conda") {
111+
finalCmd = `conda install ${pmCmd}`;
112+
}
113+
return finalCmd;
114+
}
115+
116+
function getTorchCompileUsage(optionId) {
117+
backend = getIDFromBackend(optionId);
118+
importCmd = getImportCmd(optionId) + "<br>";
119+
finalCmd = "";
120+
tcUsage = "# Torch Compile usage: " + "<br>";
121+
backendCmd = `torch.compile(model, backend="${backend}")`;
122+
libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;
123+
console.log("Surya log", finalCmd)
124+
125+
if (opts.pm == "libtorch") {
126+
return libtorchCmd;
127+
}
128+
if (backend == "inductor" || backend == "cudagraphs") {
129+
return tcUsage + backendCmd;
130+
}
131+
if (backend == "openvino") {
132+
if (opts.pm == "source") {
133+
finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "<br>" ;
134+
tcUsage += importCmd;
135+
}
136+
else if (opts.pm == "conda") {
137+
tcUsage += importCmd;
138+
}
139+
if (opts.os == "windows" && !tcUsage.includes(importCmd)) {
140+
tcUsage += importCmd;
141+
}
142+
}
143+
else{
144+
tcUsage += importCmd;
145+
}
146+
if (backend == "onnxrt") {
147+
if (opts.pm == "source") {
148+
finalCmd += "# Follow instructions at this URL to build onnxruntime from source: https://onnxruntime.ai/docs/build" + "<br>" ;
149+
}
150+
}
151+
if (backend == "tvm") {
152+
if (opts.pm == "source") {
153+
finalCmd += "# Follow instructions at this URL to build tvm from source: https://tvm.apache.org/docs/install/from_source.html" + "<br>" ;
154+
}
155+
}
156+
if (backend == "tensorrt") {
157+
if (opts.pm == "source") {
158+
finalCmd += "# Follow instructions at this URL to build tensorrt from source: https://pytorch.org/TensorRT/getting_started/installation.html#compiling-from-source" + "<br>" ;
159+
}
160+
}
161+
finalCmd += tcUsage + backendCmd;
162+
return finalCmd
163+
}
164+
165+
function addTorchCompileCommandNote(selectedOptionId) {
166+
167+
if (!selectedOptionId) {
168+
return;
169+
}
170+
if (selectedOptionId == "inductor" || selectedOptionId == "cgraphs") {
171+
$("#command").append(
172+
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
173+
);
174+
}
175+
else {
176+
$("#command").append(
177+
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
178+
);
179+
$("#command").append(
180+
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
181+
);
182+
}
183+
}
66184

67185
// determine os (mac, linux, windows) based on user's platform
68186
function getDefaultSelectedOS() {
@@ -171,6 +289,7 @@ function changeAccNoneName(osname) {
171289
function selectedOption(option, selection, category) {
172290
$(option).removeClass("selected");
173291
$(selection).addClass("selected");
292+
const previousSelection = opts[category];
174293
opts[category] = selection.id;
175294
if (category === "pm") {
176295
var elements = document.getElementsByClassName("language")[0].children;
@@ -208,13 +327,19 @@ function selectedOption(option, selection, category) {
208327
changeVersion(opts.ptbuild);
209328
//make sure unsupported platforms are disabled
210329
disableUnsupportedPlatforms(opts.os);
330+
} else if (category === "torch-compile") {
331+
if (selection.id === previousSelection) {
332+
$(selection).removeClass("selected");
333+
opts[category] = null;
334+
}
211335
}
212336
commandMessage(buildMatcher());
213337
if (category === "os") {
214338
disableUnsupportedPlatforms(opts.os);
215339
display(opts.os, 'installation', 'os');
216340
}
217341
changeAccNoneName(opts.os);
342+
addTorchCompileCommandNote(opts['torch-compile'])
218343
}
219344

220345
function display(selection, id, category) {

0 commit comments

Comments
 (0)