| 1 | +import copy |
1 | 2 | import operator
| 3 | +from typing import Callable, Dict, List, Optional, Tuple, Union |
2 | 4 |
3 | 5 | import torch
4 | 6 | import torch.fx as fx
5 | 7 |
| 8 | +from vllm.logger import init_logger |
| 9 | + |
| 10 | +from .compile_context import get_compile_context |
| 11 | +from .levels import CompilationLevel |
| 12 | + |
| 13 | +logger = init_logger(__name__) |
| 14 | + |
6 | 15 |
7 | 16 | def fix_functionalization(graph: fx.Graph):
8 | 17 | """
@@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
148 | 157 | # print(graph.python_code(root_module="self", verbose=True).src, file=f)
149 | 158 |
150 | 159 |
151 | | -def vllm_backend(graph, example_inputs): |
| 160 | +def wrap_inductor(graph, example_inputs, additional_inductor_config): |
152 | 161 | from torch._inductor import config
153 | 162 | current_config = config.shallow_copy_dict()
154 | 163 | from torch._inductor.compile_fx import compile_fx
| 164 | + |
| 165 | + if additional_inductor_config is not None: |
| 166 | + current_config.update(additional_inductor_config) |
| 167 | + if current_config['post_grad_custom_post_pass'] is not None: |
| 168 | + logger.warning( |
| 169 | + "post_grad_custom_post_pass is already set in the config. " |
| 170 | + "Overwriting it with the fix_functionalization") |
155 | 171 | current_config['post_grad_custom_post_pass'] = fix_functionalization
156 | 172 | return compile_fx(graph, example_inputs, config_patches=current_config)
| 173 | + |
| 174 | + |
| 175 | +def vllm_backend( |
| 176 | + graph, |
| 177 | + example_inputs, |
| 178 | + additional_inductor_config: Optional[Dict] = None) -> Callable: |
| 179 | + |
| 180 | + context = get_compile_context() |
| 181 | + context = copy.deepcopy(context) if context is not None else [] |
| 182 | + sizes_to_specialize: List[int] = context |
| 183 | + |
| 184 | + # flags for all the seen shapes, whether we need to specialize |
| 185 | + runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} |
| 186 | + |
| 187 | + # if we need to specialize, the compiled graph for that shape |
| 188 | + runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} |
| 189 | + |
| 190 | + # this is the first compilation: we compile a graph with dynamic |
| 191 | + # shapes, as the caller will mark the first dimension as dynamic |
| 192 | + logger.info("Compiling a graph for general shapes") |
| 193 | + graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, |
| 194 | + additional_inductor_config) |
| 195 | + |
| 196 | + # TODO: Dynamo does not pass all dynamic shapes. |
| 197 | + # Need to investigate why. It works now because all the dynamic |
| 198 | + # shapes have the same value, and either of them can be used. |
| 199 | + sym_shape_indices = [ |
| 200 | + i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) |
| 201 | + ] |
| 202 | + |
| 203 | + first_run = True |
| 204 | + |
| 205 | + # this is the function we return to Dynamo, which it calls at runtime |
| 206 | + def compiled_graph_wrapper(*args): |
| 207 | + |
| 208 | + runtime_shapes: Tuple[int, |
| 209 | + ...] = tuple(args[i] for i in sym_shape_indices) |
| 210 | + |
| 211 | + nonlocal first_run |
| 212 | + nonlocal runtime_shapes_to_compile_flags |
| 213 | + nonlocal runtime_shapes_to_compiled_graph |
| 214 | + |
| 215 | + if first_run: |
| 216 | + # the first run is for profiling; run the general graph directly |
| 217 | + first_run = False |
| 218 | + return graph_for_symbolic_shape(*args) |
| 219 | + |
| 220 | + if runtime_shapes not in runtime_shapes_to_compile_flags: |
| 221 | + # we haven't seen this shape before |
| 222 | + # query if we need to specialize for this shape |
| 223 | + # we only specialize for the first dimension. |
| 224 | + # TODO: investigate if any model needs to specialize |
| 225 | + # beyond the first dimension |
| 226 | + runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[ |
| 227 | + 0] in sizes_to_specialize |
| 228 | + |
| 229 | + if not runtime_shapes_to_compile_flags[runtime_shapes]: |
| 230 | + # we don't need to specialize for this shape |
| 231 | + return graph_for_symbolic_shape(*args) |
| 232 | + |
| 233 | + if runtime_shapes not in runtime_shapes_to_compiled_graph: |
| 234 | + # we need to specialize for this shape, and we haven't |
| 235 | + # compiled it yet: compile the graph for this shape |
| 236 | + logger.info("Compiling a graph for shapes %s", runtime_shapes) |
| 237 | + runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( |
| 238 | + graph, args, additional_inductor_config) |
| 239 | + |
| 240 | + return runtime_shapes_to_compiled_graph[runtime_shapes](*args) |
| 241 | + |
| 242 | + return compiled_graph_wrapper |
| 243 | + |
| 244 | + |
| 245 | +def select_default_backend(level: int) -> Union[str, Callable]: |
| 246 | + if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: |
| 247 | + backend = "eager" |
| 248 | + return backend |
| 249 | + assert level in [ |
| 250 | + CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE |
| 251 | + ], f"Invalid level {level}" |
| 252 | + |
| 253 | + from vllm.compilation.backends import vllm_backend |
| 254 | + from vllm.plugins import get_inductor_additional_configs |
| 255 | + additional_configs = get_inductor_additional_configs() |
| 256 | + |
| 257 | + if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE: |
| 258 | + if "max_autotune" in additional_configs and not additional_configs[ |
| 259 | + "max_autotune"]: |
| 260 | + logger.warning( |
| 261 | + "max_autotune is disabled, but is overridden by level %s", |
| 262 | + CompilationLevel.INDUCTOR_MAX_AUTOTUNE) |
| 263 | + additional_configs['max_autotune'] = True |
| 264 | + |
| 265 | + from functools import partial |
| 266 | + backend = partial(vllm_backend, |
| 267 | + additional_inductor_config=additional_configs) |
| 268 | + |
| 269 | + return backend |
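
The following is a minimal sketch (not part of this commit) of how `wrap_inductor` can be exercised on its own as a `torch.compile` backend. `toy_backend` and `f` are illustrative names, the import path assumes this file is importable as `vllm.compilation.backends` (as the local import inside `select_default_backend` suggests), and `max_autotune` is an existing Inductor config knob.

```python
# Hedged sketch: driving wrap_inductor directly as a torch.compile backend.
# toy_backend and f are illustrative names, not part of the commit.
import torch

from vllm.compilation.backends import wrap_inductor


def toy_backend(gm, example_inputs):
    # wrap_inductor merges the extra dict into Inductor's config and installs
    # fix_functionalization as the post-grad custom pass before compile_fx.
    return wrap_inductor(gm, example_inputs,
                         additional_inductor_config={"max_autotune": False})


@torch.compile(backend=toy_backend)
def f(x):
    return torch.relu(x) + 1.0


f(torch.randn(8))
```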
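
A hedged sketch of the runtime behavior of the wrapper returned by `vllm_backend`: the first call serves the profiling run and uses the general dynamic-shape graph, shapes whose first dimension is listed in the compile context get their own Inductor compilation on first sight, and every other shape falls back to the general graph. It assumes `.compile_context` also exposes a `set_compile_context` counterpart to the `get_compile_context` imported above; the toy model and the sizes are illustrative.

```python
# Hedged sketch of compiled_graph_wrapper's dispatch. Assumes a
# set_compile_context() counterpart to get_compile_context() exists in
# vllm.compilation.compile_context; toy_model and the sizes are illustrative.
import torch

from vllm.compilation.backends import vllm_backend
from vllm.compilation.compile_context import set_compile_context


def toy_model(x):
    return x * 2.0 + 1.0


with set_compile_context([1, 2, 4]):       # first-dim sizes to specialize
    compiled = torch.compile(toy_model, backend=vllm_backend)

    x = torch.randn(8, 16)
    torch._dynamo.mark_dynamic(x, 0)       # caller marks dim 0 as dynamic
    compiled(x)                            # first call: profiling run,
                                           # general graph

    compiled(torch.randn(4, 16))           # 4 is in the list: per-shape graph
    compiled(torch.randn(3, 16))           # 3 is not: general graph reused
```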
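
A hedged sketch of how `select_default_backend` is meant to be consumed: translate a `CompilationLevel` into a `torch.compile` backend. `toy_forward` is illustrative; the module paths follow the imports used in this commit.

```python
# Hedged sketch: mapping a CompilationLevel to a torch.compile backend.
# toy_forward is illustrative; module paths follow the imports in this commit.
import torch

from vllm.compilation.backends import select_default_backend
from vllm.compilation.levels import CompilationLevel


def toy_forward(x):
    return torch.nn.functional.gelu(x)


# DYNAMO_AS_IS / DYNAMO_ONCE map to the "eager" backend string;
# INDUCTOR / INDUCTOR_MAX_AUTOTUNE map to a partial of vllm_backend that
# carries any additional Inductor configs registered via vllm.plugins.
backend = select_default_backend(CompilationLevel.INDUCTOR)
compiled = torch.compile(toy_forward, backend=backend)
compiled(torch.randn(4, 8))
```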