Commit 8015490

using nccl ops from TRT-LLM namespace
1 parent 6d40ff1

5 files changed: +266 -8 lines changed

examples/distributed_inference/README.md

Lines changed: 4 additions & 0 deletions
@@ -14,3 +14,7 @@ See the examples started with `data_parallel` for more details.
 Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
 
 torchrun --nproc_per_node=2 tensor_parallel_llama2.py
+
+3. Tensor parallel distributed inference using nccl ops plugin
+
+mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
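
Note: the README's requirement that the module be "properly sharded" can be made concrete with a small sketch. The code below is illustrative only and is not part of this commit; the ToyMLP module and its layer names are hypothetical, and it assumes the process group has already been set up by the mpirun/torchrun launch shown above.

# Illustrative sketch (not from this commit): sharding a toy module with
# torch.distributed tensor parallelism before handing it to Torch-TensorRT.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):  # hypothetical module, not part of the commit
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(16, 32)
        self.out_proj = nn.Linear(32, 16)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


# One-dimensional mesh over the ranks launched by mpirun/torchrun (2 here).
mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))

# Column-parallel first layer, row-parallel second layer: each rank holds a
# shard of the weights, and the collectives this inserts are what the NCCL
# plugin converters later lower to TensorRT.
tp_model = parallelize_module(
    ToyMLP().cuda(),
    mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)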
examples/distributed_inference/requirements.txt

Lines changed: 4 additions & 1 deletion

@@ -1,3 +1,6 @@
 accelerate
 transformers
-diffusers
+diffusers
+site
+# Install tensorrt-llm without its dependencies (use the command separately): pip install tensorrt-llm --no-deps
+tensorrt-llm

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 184 additions & 7 deletions
@@ -1,8 +1,17 @@
+import ctypes
+import logging
 import os
+import site
 import sys
 import time
+from enum import IntEnum, IntFlag, auto
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
+import numpy as np
+import tensorrt as trt
+import tensorrt_llm
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from torch.distributed._tensor import Shard
@@ -12,6 +21,181 @@
     RowwiseParallel,
     parallelize_module,
 )
+from torch.fx import GraphModule, Node
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    custom_fused_all_gather_op,
+    custom_fused_reduce_scatter_op,
+)
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+
+
+# This is required for env initialization since we use mpirun
+def initialize(rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    return local_rank, world_size
+
+
+initialize()
+# create a device mesh based on the given world_size.
+_world_size = int(os.environ["WORLD_SIZE"])
+
+device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+_rank = device_mesh.get_rank()
+device_id = _rank % torch.cuda.device_count()  # Ensure each rank gets a unique device
+torch.cuda.set_device(device_id)
+
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+fh = logging.FileHandler(f"./tensor_parallel_simple_example_{_rank}.log", mode="w")
+fh.setLevel(logging.INFO)
+logger.addHandler(fh)
+
+
+# TensorRT NCCL plugins
+tensorrt_llm_lib_path = tensorrt_llm.__file__
+plugin_lib_path = tensorrt_llm_lib_path + "/libs/libnvinfer_plugin_tensorrt_llm.so"
+try:
+    ctypes.CDLL(plugin_lib_path)
+    logger.info(f"plugin loaded successfully")
+except OSError as e:
+    logger.info(f"unsuccessful load : {e}")
+trt.init_libnvinfer_plugins(None, "")
+# Iterate over all registered plugin creators
+plugin_registry = trt.get_plugin_registry()
+for plugin_creator in plugin_registry.plugin_creator_list:
+    logger.info(
+        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+    )
+
+
+# class for AllReduce
+class AllReduceStrategy(IntEnum):
+    """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+    They must be kept in sync.
+    """
+
+    NCCL = 0
+    ONESHOT = 1
+    TWOSHOT = 2
+    AUTO = 3
+
+
+class AllReduceConfig(IntFlag):
+    """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+    They must be kept in sync
+    """
+
+    USE_MEMCPY = auto()
+    PUSH_MODE = auto()
+
+
+@dynamo_tensorrt_converter(custom_fused_all_gather_op)
+def insert_nccl_gather_op(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "AllGather", "1", "tensorrt_llm"
+    )
+    assert allgather_plg_creator is not None
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    p_dtype = trt.float16
+    pf_type = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = trt.PluginFieldCollection([group, pf_type])
+    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
+    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
+def insert_nccl_reduce_scatter_plugin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "ReduceScatter", "1", "tensorrt_llm"
+    )
+
+    assert allreduce_plg_creator is not None
+
+    counter = 0
+    strategy = AllReduceStrategy.NCCL
+    config = AllReduceConfig(0)
+
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+
+    p_dtype = trt.float16
+    pf_dtype = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = [group, pf_dtype]
+    p_strategy = trt.PluginField(
+        "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
+    )
+    pfc.append(p_strategy)
+    p_config = trt.PluginField(
+        "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
+    )
+    pfc.append(p_config)
+    p_counter = trt.PluginField(
+        "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc.append(p_counter)
+
+    pfc = trt.PluginFieldCollection(pfc)
+    ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)
+
+    layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -36,13 +220,6 @@ def forward(self, x):
         return x
 
 
-# create a device mesh based on the given world_size.
-_world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
-_rank = device_mesh.get_rank()
-
-
 print(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
     _world_size % 2 == 0
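
Note: the hunks above cover imports, environment setup, TRT-LLM plugin loading, and converter registration; the part of the example that compiles and runs the sharded model is unchanged by this commit and therefore not shown. As a rough, hypothetical sketch (the tp_model variable, option names, and tensor shape are assumptions, not taken from the diff), compilation through the Torch-TensorRT backend of torch.compile is what routes the fused collective ops to the converters registered above:

# Hypothetical sketch (not part of this diff): compiling the parallelized
# module so the fused all_gather/reduce_scatter ops reach the NCCL plugin
# converters registered earlier in this file.
inp = torch.rand(20, 10, device=f"cuda:{_rank}")

tp_model = torch.compile(
    tp_model,  # assumed to be the module returned by parallelize_module
    backend="torch_tensorrt",
    options={
        "enabled_precisions": {torch.float16},
        "use_python_runtime": True,
        "min_block_size": 1,
    },
    dynamic=False,
)

with torch.no_grad():
    output = tp_model(inp)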

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
+from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_linear import lower_linear
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
@@ -26,6 +27,7 @@
         lower_scaled_dot_product_attention,
         lower_linear,
         fuse_prims_broadcast,
+        fuse_distributed_ops,
         replace_max_pool_with_indices,
         replace_full_like_with_full,
         view_to_reshape,
py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+import logging
+from typing import Sequence
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+
+# dead-code elimination, linting, and recompilation for graph, in-place
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def custom_fused_all_gather_op(args0, args1, args2):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
+    )
+
+
+def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            args0, args1, args2, args3
+        )
+    )
+
+
+def fuse_distributed_ops(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    modified_graph = False
+    for node in gm.graph.nodes:
+        if (
+            node.target
+            in (
+                torch.ops._c10d_functional.all_gather_into_tensor.default,
+                torch.ops._c10d_functional.reduce_scatter_tensor.default,
+            )
+            and len(node.users) == 1
+            and list(node.users)[0].target
+            == torch.ops._c10d_functional.wait_tensor.default
+        ):
+            wait_tensor_node = list(node.users)[0]
+            fused_op = None
+            if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
+                fused_op = custom_fused_all_gather_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2])
+            else:
+                fused_op = custom_fused_reduce_scatter_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
+            with gm.graph.inserting_after(wait_tensor_node):
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=fused_op,  # Define your custom fused function
+                    args=fused_op_args,
+                )
+
+            wait_tensor_node.replace_all_uses_with(fused_node)
+            fused_node.meta.update(node.meta)
+            modified_graph = True
+            gm.graph.erase_node(wait_tensor_node)
+            gm.graph.erase_node(node)
+
+    # If graph was modified, clean it up
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}"
+        )
+
+    return gm
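
Note: conceptually, fuse_distributed_ops rewrites the functional-collective pattern "collective followed by wait_tensor" into a single call to one of the custom_fused_* wrappers, which is the node shape the TRT-LLM plugin converters in tensor_parallel_simple_example.py are registered against. The sketch below is illustrative only; the argument names are placeholders and do not come from the commit.

# Illustrative only: the node pattern fuse_distributed_ops matches ("before")
# and the single fused node it leaves in the graph instead ("after").
import torch
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    custom_fused_all_gather_op,
)


def before(x, group_size, group_name):
    # Two graph nodes: the functional all_gather followed by its wait_tensor.
    gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(
        x, group_size, group_name
    )
    return torch.ops._c10d_functional.wait_tensor.default(gathered)


def after(x, group_size, group_name):
    # One graph node whose target is the custom fused wrapper, which the
    # TRT-LLM AllGather plugin converter can then match.
    return custom_fused_all_gather_op(x, group_size, group_name)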
