
Commit 701822c

[linting] Enable ruff on more files (wave 2/N)
Signed-off-by: William Zhang <[email protected]>
1 parent 8454640 commit 701822c

24 files changed (+1784 additions, -1692 deletions)

.git-blame-ignore-revs

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
1+
# Ruff formatting / linting adoption.
2+
dc52b67492b2f6531e310bed90f88c8427ad3908

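Adding the Ruff adoption commit to `.git-blame-ignore-revs` keeps `git blame` pointing at the real authors of a line rather than at the bulk reformatting. Git only honors the file once it is configured; a minimal local setup, run from the repository root, is:

    git config blame.ignoreRevsFile .git-blame-ignore-revs

IDEs and blame viewers that respect this setting will then skip the listed revision as well.
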
pyproject.toml

Lines changed: 57 additions & 2 deletions
@@ -33,6 +33,20 @@ extend_skip_glob = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    # Phase 2.
+    "tensorrt_llm/_tensorrt_engine/*.py",
+    "tensorrt_llm/_torch/auto_deploy/custom_ops/torch_libs/*.py",
+    "tensorrt_llm/_torch/debug/*.py",
+    "tensorrt_llm/_torch/shared_tensor/*.py",
+    "tensorrt_llm/_torch/peft/*.py",
+    "tensorrt_llm/evaluate/lm_eval_tasks/gpqa/cot_zeroshot_aa/*.py",
+    "tensorrt_llm/models/clip/*.py",
+    "tensorrt_llm/models/internlm/*.py",
+    "tensorrt_llm/models/mmdit_sd3/*.py",
+    "tensorrt_llm/models/multimodal_encoders/*.py",
+    "tensorrt_llm/models/skywork/*.py",
+    "tensorrt_llm/models/stdit/*.py",
+    "tensorrt_llm/scaffolding/contrib/**/*.py",
 ]
 
 [tool.yapf]
@@ -63,6 +77,20 @@ ignore_patterns = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    # Phase 2.
+    "tensorrt_llm/_tensorrt_engine/*.py",
+    "tensorrt_llm/_torch/auto_deploy/custom_ops/torch_libs/*.py",
+    "tensorrt_llm/_torch/debug/*.py",
+    "tensorrt_llm/_torch/shared_tensor/*.py",
+    "tensorrt_llm/_torch/peft/*.py",
+    "tensorrt_llm/evaluate/lm_eval_tasks/gpqa/cot_zeroshot_aa/*.py",
+    "tensorrt_llm/models/clip/*.py",
+    "tensorrt_llm/models/internlm/*.py",
+    "tensorrt_llm/models/mmdit_sd3/*.py",
+    "tensorrt_llm/models/multimodal_encoders/*.py",
+    "tensorrt_llm/models/skywork/*.py",
+    "tensorrt_llm/models/stdit/*.py",
+    "tensorrt_llm/scaffolding/contrib/**/*.py",
 ]
 
 [tool.codespell]
@@ -74,8 +102,7 @@ ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te
 in-place = true
 remove_all_unused_imports = true
 remove_unused_variables = true
-# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
-# is necessary.
+# This should match the `include` in `[tool.ruff]`. The reason is that it is setup to run on the entire codebase.
 exclude = [
     "**/auto_deploy/**",
     "tensorrt_llm/_common.py",
@@ -97,6 +124,20 @@ exclude = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    # Phase 2.
+    "tensorrt_llm/_tensorrt_engine/*.py",
+    "tensorrt_llm/_torch/auto_deploy/custom_ops/torch_libs/*.py",
+    "tensorrt_llm/_torch/debug/*.py",
+    "tensorrt_llm/_torch/shared_tensor/*.py",
+    "tensorrt_llm/_torch/peft/*.py",
+    "tensorrt_llm/evaluate/lm_eval_tasks/gpqa/cot_zeroshot_aa/*.py",
+    "tensorrt_llm/models/clip/*.py",
+    "tensorrt_llm/models/internlm/*.py",
+    "tensorrt_llm/models/mmdit_sd3/*.py",
+    "tensorrt_llm/models/multimodal_encoders/*.py",
+    "tensorrt_llm/models/skywork/*.py",
+    "tensorrt_llm/models/stdit/*.py",
+    "tensorrt_llm/scaffolding/contrib/**/*.py",
 ]
 
@@ -140,6 +181,20 @@ include = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    # Phase 2.
+    "tensorrt_llm/_tensorrt_engine/*.py",
+    "tensorrt_llm/_torch/auto_deploy/custom_ops/torch_libs/*.py",
+    "tensorrt_llm/_torch/debug/*.py",
+    "tensorrt_llm/_torch/shared_tensor/*.py",
+    "tensorrt_llm/_torch/peft/*.py",
+    "tensorrt_llm/evaluate/lm_eval_tasks/gpqa/cot_zeroshot_aa/*.py",
+    "tensorrt_llm/models/clip/*.py",
+    "tensorrt_llm/models/internlm/*.py",
+    "tensorrt_llm/models/mmdit_sd3/*.py",
+    "tensorrt_llm/models/multimodal_encoders/*.py",
+    "tensorrt_llm/models/skywork/*.py",
+    "tensorrt_llm/models/stdit/*.py",
+    "tensorrt_llm/scaffolding/contrib/**/*.py",
 ]
 exclude = [
     "**3rdparty/**",

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from tensorrt_llm.llmapi.llm import _TrtLLM as LLM
 
-__all__ = ['LLM']
+__all__ = ["LLM"]

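The quote flip above is simply what `ruff format` produces: like Black, it normalizes string literals to double quotes by default, so no per-file configuration is needed for this change.
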
tensorrt_llm/_torch/debug/debug_hook.py

Lines changed: 21 additions & 25 deletions
@@ -98,8 +98,7 @@ def get_module_indices_tree(self):
         return self.layer_inner_counter
 
     def get_current_model_loop_index(self):
-        return self.layer_inner_counter[0] + 1 if len(
-            self.layer_inner_counter) >= 1 else 0
+        return self.layer_inner_counter[0] + 1 if len(self.layer_inner_counter) >= 1 else 0
 
     def do_actions(self, module, tensors, actions):
         assert isinstance(actions, list), "Actions shall be list."
@@ -109,7 +108,6 @@ def do_actions(self, module, tensors, actions):
 
 
 class Filter:
-
     def __init__(self):
         pass
 
@@ -151,14 +149,12 @@ def pre_forward(module: nn.Module, args, kwargs):
     if len(debug_ctx.get_module_indices_tree()) == 0:
         debug_ctx.get_module_indices_tree().append(0)
 
-    if len(debug_ctx.get_current_modules_tree()) >= len(
-            debug_ctx.get_module_indices_tree()):
+    if len(debug_ctx.get_current_modules_tree()) >= len(debug_ctx.get_module_indices_tree()):
         debug_ctx.get_module_indices_tree().append(0)
 
-    debug_ctx.get_module_indices_tree()[
-        len(debug_ctx.get_current_modules_tree()) -
-        1] = debug_ctx.get_module_indices_tree()[
-            len(debug_ctx.get_current_modules_tree()) - 1] + 1
+    debug_ctx.get_module_indices_tree()[len(debug_ctx.get_current_modules_tree()) - 1] = (
+        debug_ctx.get_module_indices_tree()[len(debug_ctx.get_current_modules_tree()) - 1] + 1
+    )
     debug_ctx.do_actions(module, args, debug_ctx.get_pre_forward_action())
     return None
 
@@ -179,8 +175,7 @@ def after_forward(module: nn.Module, args, kwargs, output):
     """
     debug_ctx = get_current_debug_ctx()
     debug_ctx.mark_in_pre_forward(False)
-    debug_ctx.do_actions(module, [args, output],
-                         debug_ctx.get_after_forward_action())
+    debug_ctx.do_actions(module, [args, output], debug_ctx.get_after_forward_action())
     name = module.name if hasattr(module, "name") else module.__class__.__name__
     old_name = debug_ctx.get_current_modules_tree().pop(-1)
     assert name == old_name, "module mismatch"
@@ -189,9 +184,9 @@ def after_forward(module: nn.Module, args, kwargs, output):
     return None
 
 
-def enable_debug(model: nn.Module,
-                 dest_folder: Optional[str] = None,
-                 filter: Optional[Filter] = None):
+def enable_debug(
+    model: nn.Module, dest_folder: Optional[str] = None, filter: Optional[Filter] = None
+):
     """
     The function style to interface to enable debugger on model.
     If filter is provided, it will be used to filter out satisfied module to register hook.
@@ -231,16 +226,16 @@ def enable_debug(model: nn.Module,
         if submodule not in debug_ctx.forward_hook_handles:
             do_hook = filter(submodule) if filter is not None else True
             if do_hook:
-                debug_ctx.forward_hook_handles[
-                    submodule] = submodule.register_forward_hook(
-                        after_forward, with_kwargs=True, always_call=True)
+                debug_ctx.forward_hook_handles[submodule] = submodule.register_forward_hook(
+                    after_forward, with_kwargs=True, always_call=True
+                )
 
         if submodule not in debug_ctx.forward_pre_hook_handles:
             do_hook = filter(submodule) if filter is not None else True
             if do_hook:
-                debug_ctx.forward_pre_hook_handles[
-                    submodule] = submodule.register_forward_pre_hook(
-                        pre_forward, with_kwargs=True)
+                debug_ctx.forward_pre_hook_handles[submodule] = submodule.register_forward_pre_hook(
+                    pre_forward, with_kwargs=True
+                )
 
 
 def disable_debug():
@@ -262,9 +257,9 @@ def disable_debug():
 
 
 @contextmanager
-def debug_mode(model: nn.Module,
-               dest_folder: Optional[str] = None,
-               filter: Optional[Filter] = None):
+def debug_mode(
+    model: nn.Module, dest_folder: Optional[str] = None, filter: Optional[Filter] = None
+):
     """
     The context manager style interface to enable debugger on model.
     If filter is provided, it will be used to filter out satisfied module to register hook.
@@ -329,8 +324,9 @@ def dump_tensor(module: nn.Module, data_tensor, debug_ctx: DebuggerContext):
     def get_dump_file_path(tensor):
         nonlocal tensor_counter
         nonlocal input_tensor_names
-        assert debug_ctx.get_log_folder(
-        ) is not None, "Log folder shall be initialized by DebugContext."
+        assert debug_ctx.get_log_folder() is not None, (
+            "Log folder shall be initialized by DebugContext."
+        )
 
         name_parts = []
         for idx in range(len(debug_ctx.get_current_modules_tree())):

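The changes above are formatting only; the debugger's two entry points keep their behavior, with `enable_debug()` registering forward / pre-forward hooks on each (optionally filtered) submodule and `debug_mode()` wrapping the same logic in a context manager. A minimal usage sketch based on the signatures shown in this diff (the model, input, and dump folder are illustrative placeholders, not part of this commit):

    import torch
    import torch.nn as nn

    from tensorrt_llm._torch.debug.debug_hook import debug_mode

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU())  # stand-in for a real module
    x = torch.randn(2, 16)

    # While the context is active, tensors seen by the hooked submodules are
    # dumped under dest_folder (assumed here to be a writable path).
    with debug_mode(model, dest_folder="/tmp/trtllm_debug"):
        model(x)
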
tensorrt_llm/_torch/peft/lora/layer.py

Lines changed: 36 additions & 27 deletions
@@ -10,6 +10,7 @@ class LoraModuleType(IntEnum):
     This enum maps to the different attention and MLP components in a transformer model
     that can be adapted using LoRA weights.
     """
+
     ATTENTION_QKV = 0  # Combined QKV projection
     ATTENTION_Q = 1  # Query projection
     ATTENTION_K = 2  # Key projection
@@ -60,32 +61,37 @@ def from_string(cls, name: str) -> "LoraModuleType":
     def is_attention(self) -> bool:
         """Check if this is an attention module type."""
         return self in {
-            self.ATTENTION_QKV, self.ATTENTION_Q, self.ATTENTION_K,
-            self.ATTENTION_V, self.ATTENTION_DENSE, self.CROSS_ATTENTION_QKV,
-            self.CROSS_ATTENTION_Q, self.CROSS_ATTENTION_K,
-            self.CROSS_ATTENTION_V, self.CROSS_ATTENTION_DENSE
+            self.ATTENTION_QKV,
+            self.ATTENTION_Q,
+            self.ATTENTION_K,
+            self.ATTENTION_V,
+            self.ATTENTION_DENSE,
+            self.CROSS_ATTENTION_QKV,
+            self.CROSS_ATTENTION_Q,
+            self.CROSS_ATTENTION_K,
+            self.CROSS_ATTENTION_V,
+            self.CROSS_ATTENTION_DENSE,
         }
 
     @property
     def is_mlp(self) -> bool:
         """Check if this is an MLP module type."""
         return self in {
-            self.MLP_H_TO_4H, self.MLP_4H_TO_H, self.MLP_GATE, self.MLP_GATE_UP,
-            self.MLP_ROUTER
+            self.MLP_H_TO_4H,
+            self.MLP_4H_TO_H,
+            self.MLP_GATE,
+            self.MLP_GATE_UP,
+            self.MLP_ROUTER,
         }
 
     @property
     def is_moe(self) -> bool:
         """Check if this is a Mixture of Experts (MoE) module type."""
-        return self in {
-            self.MOE_H_TO_4H, self.MOE_4H_TO_H, self.MOE_GATE, self.MOE_ROUTER
-        }
+        return self in {self.MOE_H_TO_4H, self.MOE_4H_TO_H, self.MOE_GATE, self.MOE_ROUTER}
 
 
 class LoraLayer(torch.nn.Module):
-
-    def __init__(self, lora_module_types: List[LoraModuleType],
-                 output_hidden_sizes: List[int]):
+    def __init__(self, lora_module_types: List[LoraModuleType], output_hidden_sizes: List[int]):
         super().__init__()
 
         self.lora_module_types = lora_module_types
98104
lora_params: Dict,
99105
layer_idx: int,
100106
) -> Optional[torch.Tensor]:
101-
102107
if bool(lora_params):
103108
lora_ranks = []
104109
lora_weight_pointers = []
@@ -108,23 +113,23 @@ def forward(
108113
if module_idx in lora_params[layer_idx]:
109114
active_lora_module_ids.append(module_idx)
110115
# TODO (dafrimi): needs to pass this is_dora arg
111-
lora_params[layer_idx][module_idx]['is_dora']
112-
lora_ranks.append(
113-
lora_params[layer_idx][module_idx]['adapter_size'])
116+
lora_params[layer_idx][module_idx]["is_dora"]
117+
lora_ranks.append(lora_params[layer_idx][module_idx]["adapter_size"])
114118
lora_weight_pointers.append(
115-
lora_params[layer_idx][module_idx]['weight_pointers'])
119+
lora_params[layer_idx][module_idx]["weight_pointers"]
120+
)
116121

117-
num_seqs = lora_params['num_seqs']
122+
num_seqs = lora_params["num_seqs"]
118123

119124
if len(active_lora_module_ids) == 0:
120125
return None
121126
else:
122127
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
123128
x,
124-
lora_params['host_request_types'][:num_seqs],
129+
lora_params["host_request_types"][:num_seqs],
125130
lora_ranks,
126131
lora_weight_pointers,
127-
lora_params['prompt_lens_cpu'][:num_seqs],
132+
lora_params["prompt_lens_cpu"][:num_seqs],
128133
self.output_hidden_sizes,
129134
False, # transA
130135
True, # transB
@@ -144,13 +149,17 @@ def forward(
144149
lora_output.append(lora_outputs.pop(0))
145150
else:
146151
lora_output.append(
147-
torch.zeros(list(x.shape[:-1]) + [
148-
self.output_hidden_sizes[
149-
self.lora_module_types.index(
150-
module_idx)]
151-
],
152-
dtype=x.dtype,
153-
device=x.device))
152+
torch.zeros(
153+
list(x.shape[:-1])
154+
+ [
155+
self.output_hidden_sizes[
156+
self.lora_module_types.index(module_idx)
157+
]
158+
],
159+
dtype=x.dtype,
160+
device=x.device,
161+
)
162+
)
154163
lora_output = torch.cat(lora_output, dim=-1)
155164
return lora_output
156165

Lines changed: 2 additions & 3 deletions
@@ -1,9 +1,8 @@
-from .shared_tensor import (SharedTensorContainer,
-                            _SharedTensorRebuildMethodRegistry)
+from .shared_tensor import SharedTensorContainer, _SharedTensorRebuildMethodRegistry
 
 # Initialize the registry when the package is imported
 _SharedTensorRebuildMethodRegistry.initialize()
 
 __all__ = [
-    'SharedTensorContainer',
+    "SharedTensorContainer",
 ]
