diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4feea786f38b..3c3da41c3abf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -60,11 +60,13 @@ steps: mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - - pip install awscli + # install tensorizer for tensorize_vllm_model.py + - pip install awscli tensorizer - python3 offline_inference.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py + - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Kernels Test %N command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index e2456168de9d..8b74ae1d75a1 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -1,23 +1,20 @@ import argparse import dataclasses +import json import os -import time import uuid from functools import partial -from typing import Type -import torch -import torch.nn as nn -from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, - TensorSerializer, stream_io) -from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor -from transformers import AutoConfig, PretrainedConfig +from tensorizer import stream_io -from vllm.distributed import initialize_model_parallel +from vllm import LLM +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.model_executor.model_loader.tensorizer import TensorizerArgs -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs, + TensorizerConfig, + serialize_vllm_model) # yapf conflicts with isort for this docstring # yapf: disable @@ -27,25 +24,25 @@ to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, or locally. Tensor encryption and decryption is also supported, although libsodium must be installed to use it. Install vllm with tensorizer support -using `pip install vllm[tensorizer]`. +using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit +https://github.com/coreweave/tensorizer To serialize a model, install vLLM from source, then run something like this from the root level of this repository: python -m examples.tensorize_vllm_model \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ + --model facebook/opt-125m \ serialize \ - --serialized-directory s3://my-bucket/ \ - --suffix vllm + --serialized-directory s3://my-bucket \ + --suffix v1 Which downloads the model from HuggingFace, loads it into vLLM, serializes it, and saves it to your S3 bucket. A local directory can also be used. This assumes your S3 credentials are specified as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. -To provide S3 credentials directly, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this -script. +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and +`S3_ENDPOINT_URL`. 
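# --- Illustrative sketch (not part of this diff) -----------------------------
# The docstring above describes supplying S3 credentials through environment
# variables before serializing. A minimal way to do that programmatically and
# then run the documented CLI from the repository root is sketched below; the
# bucket, endpoint, and credential values are placeholders, not real secrets.
import os
import subprocess

os.environ["S3_ACCESS_KEY_ID"] = "<access-key-id>"          # placeholder
os.environ["S3_SECRET_ACCESS_KEY"] = "<secret-access-key>"  # placeholder
os.environ["S3_ENDPOINT_URL"] = "<https://my-object-store>" # placeholder

# Same invocation as the serialize example shown in the docstring above.
subprocess.run(
    [
        "python", "-m", "examples.tensorize_vllm_model",
        "--model", "facebook/opt-125m",
        "serialize",
        "--serialized-directory", "s3://my-bucket",
        "--suffix", "v1",
    ],
    check=True,
)
# -----------------------------------------------------------------------------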
To provide S3 credentials directly, you can provide +`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint` +as CLI args to this script. You can also encrypt the model weights with a randomly-generated key by providing a `--keyfile` argument. @@ -57,7 +54,7 @@ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ - --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors Which downloads the model tensors from your S3 bucket and deserializes them. @@ -71,26 +68,30 @@ `python -m examples.tensorize_vllm_model deserialize --help`. -Once a model is serialized, it can be used to load the model when running the -OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing -the `--tensorizer-uri` CLI argument that is functionally the same as the -`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to -signify that the model to be deserialized is a vLLM model, rather than a -HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer -in the same inference server, albeit without the speed optimizations. To -deserialize an encrypted file, the `--encryption-keyfile` argument can be used -to provide the path to the keyfile used to encrypt the model weights. For -information on all the arguments that can be used to configure tensorizer's -deserialization, check out the tensorizer options argument group in the -`vllm/entrypoints/openai/api_server.py` script with `--help`. - -Tensorizer can also be invoked with the `LLM` class directly to load models: +Once a model is serialized, tensorizer can be invoked with the `LLM` class +directly to load models: llm = LLM(model="facebook/opt-125m", load_format="tensorizer", - tensorizer_uri=path_to_opt_tensors, - num_readers=3, - vllm_tensorized=True) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri = path_to_tensors, + num_readers=3, + ) + ) + +A serialized model can be used during model loading for the vLLM OpenAI +inference server. `model_loader_extra_config` is exposed as the CLI arg +`--model-loader-extra-config`, and accepts a JSON string literal of the +TensorizerConfig arguments desired. + +In order to see all of the available arguments usable to configure +loading with tensorizer that are given to `TensorizerConfig`, run: + +`python -m examples.tensorize_vllm_model deserialize --help` + +under the `tensorizer options` section. These can also be used for +deserialization in this example script, although `--tensorizer-uri` and +`--path-to-tensors` are functionally the same in this case. """ @@ -158,95 +159,35 @@ def parse_args(): help=("Path to a binary key to use to decrypt the model weights," " if the model was serialized with encryption")) - return parser.parse_args() - - -def make_model_contiguous(model): - # Ensure tensors are saved in memory contiguously - for param in model.parameters(): - param.data = param.data.contiguous() - - -def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. 
" - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def serialize(): - - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) - engine = LLMEngine.from_engine_args(engine_args) + TensorizerArgs.add_cli_args(deserialize_parser) - model = (engine.model_executor.driver_worker. - model_runner.model) - - encryption_params = EncryptionParams.random() if keyfile else None - if keyfile: - with _write_stream(keyfile) as stream: - stream.write(encryption_params.key) - - with _write_stream(model_path) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - serializer.close() + return parser.parse_args() - print("Serialization complete. Model tensors saved to", model_path) - if keyfile: - print("Key saved to", keyfile) def deserialize(): - config = AutoConfig.from_pretrained(model_ref) - - with no_init_or_tensor(): - model_class = _get_vllm_model_architecture(config) - model = model_class(config) - - before_mem = get_mem_usage() - start = time.time() - - if keyfile: - with _read_stream(keyfile) as stream: - key = stream.read() - decryption_params = DecryptionParams.from_key(key) - tensorizer_args.deserializer_params['encryption'] = \ - decryption_params - - with (_read_stream(model_path)) as stream, TensorDeserializer( - stream, **tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(model) - end = time.time() - - # Brag about how fast we are. - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - print( - f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + llm = LLM(model=args.model, + load_format="tensorizer", + model_loader_extra_config=tensorizer_config ) - print(f"Memory usage before: {before_mem}") - print(f"Memory usage after: {after_mem}") + return llm - return model args = parse_args() -s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") - or None) -s3_secret_access_key = (args.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY") or None) +s3_access_key_id = (getattr(args, 's3_access_key_id', None) + or os.environ.get("S3_ACCESS_KEY_ID", None)) +s3_secret_access_key = (getattr(args, 's3_secret_access_key', None) + or os.environ.get("S3_SECRET_ACCESS_KEY", None)) +s3_endpoint = (getattr(args, 's3_endpoint', None) + or os.environ.get("S3_ENDPOINT_URL", None)) -s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) +credentials = { + "s3_access_key_id": s3_access_key_id, + "s3_secret_access_key": s3_secret_access_key, + "s3_endpoint": s3_endpoint +} _read_stream, _write_stream = (partial( stream_io.open_stream, @@ -263,20 +204,41 @@ def deserialize(): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "8080" -torch.distributed.init_process_group(world_size=1, rank=0) +init_distributed_environment(world_size=1, rank=0, local_rank=0) initialize_model_parallel() keyfile = args.keyfile if args.keyfile else None + +if args.model_loader_extra_config: + config = json.loads(args.model_loader_extra_config) + tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args() + tensorizer_args.tensorizer_uri = args.path_to_tensors +else: + tensorizer_args = None + if args.command == "serialize": + eng_args_dict = {f.name: 
getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + input_dir = args.serialized_directory.rstrip('/') suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" model_path = f"{base_path}/model.tensors" - serialize() + tensorizer_config = TensorizerConfig( + tensorizer_uri=model_path, + **credentials) + serialize_vllm_model(engine, tensorizer_config, keyfile) elif args.command == "deserialize": - tensorizer_args = TensorizerArgs.from_cli_args(args) - model_path = args.path_to_tensors + if not tensorizer_args: + tensorizer_config = TensorizerConfig( + tensorizer_uri=args.path_to_tensors, + encryption_keyfile = keyfile, + **credentials + ) deserialize() else: raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-dev.txt b/requirements-dev.txt index 796c9e37d023..4f6c27d95fe6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,7 +14,7 @@ types-setuptools # testing pytest -tensorizer==2.9.0 +tensorizer>=2.9.0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 0dc8818b44a9..a66af2c5d556 100644 --- a/setup.py +++ b/setup.py @@ -426,7 +426,7 @@ def _read_requirements(filename: str) -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "tensorizer": ["tensorizer==2.9.0"], + "tensorizer": ["tensorizer>=2.9.0"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py deleted file mode 100644 index 0e113ab647e6..000000000000 --- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py +++ /dev/null @@ -1,245 +0,0 @@ -import argparse -import dataclasses -import os -import time -import uuid -from functools import partial -from typing import Type - -import torch.nn as nn -from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, - TensorSerializer, stream_io) -from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor -from transformers import AutoConfig, PretrainedConfig - -from vllm.distributed import (init_distributed_environment, - initialize_model_parallel) -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.model_executor.model_loader.tensorizer import TensorizerArgs -from vllm.model_executor.models import ModelRegistry - -# yapf conflicts with isort for this docstring -# yapf: disable -""" -tensorize_vllm_model.py is a script that can be used to serialize and -deserialize vLLM models. These models can be loaded using tensorizer directly -to the GPU extremely quickly. Tensor encryption and decryption is also -supported, although libsodium must be installed to use it. Install -vllm with tensorizer support using `pip install vllm[tensorizer]`. - -To serialize a model, you can run something like this: - -python tensorize_vllm_model.py \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ - serialize \ - --serialized-directory s3://my-bucket/ \ - --suffix vllm - -Which downloads the model from HuggingFace, loads it into vLLM, serializes it, -and saves it to your S3 bucket. A local directory can also be used. 
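# --- Illustrative sketch (not part of this diff) -----------------------------
# The rewritten examples/tensorize_vllm_model.py above deserializes by handing
# a TensorizerConfig to LLM via model_loader_extra_config. A condensed version
# of that flow is sketched here; the tensors path and keyfile are placeholders.
# S3 credentials can also be passed as the s3_* keyword arguments the script
# forwards to TensorizerConfig.
import json

from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

path_to_tensors = "s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors"  # placeholder
keyfile = None  # or the path to the key written during encrypted serialization

llm = LLM(
    model="facebook/opt-125m",
    load_format="tensorizer",
    model_loader_extra_config=TensorizerConfig(
        tensorizer_uri=path_to_tensors,
        encryption_keyfile=keyfile,
        num_readers=1,
    ),
)

# The OpenAI server accepts the same settings as a JSON string literal:
#   --load-format tensorizer \
#   --model-loader-extra-config '{"tensorizer_uri": "...", "num_readers": 1}'
print(json.dumps({"tensorizer_uri": path_to_tensors, "num_readers": 1}))
# -----------------------------------------------------------------------------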
- -You can also encrypt the model weights with a randomly-generated key by -providing a `--keyfile` argument. - -To deserialize a model, you can run something like this: - -python tensorize_vllm_model.py \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ - deserialize \ - --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors - -Which downloads the model tensors from your S3 bucket and deserializes them. -To provide S3 credentials, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, -the OpenAI entrypoint, as arguments for LLM(), or as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. - - -You can also provide a `--keyfile` argument to decrypt the model weights if -they were serialized with encryption. - -For more information on the available arguments, run -`python tensorize_vllm_model.py --help`. -""" - - -def parse_args(): - parser = argparse.ArgumentParser( - description="An example script that can be used to serialize and " - "deserialize vLLM models. These models " - "can be loaded using tensorizer directly to the GPU " - "extremely quickly. Tensor encryption and decryption is " - "also supported, although libsodium must be installed to " - "use it.") - parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser)) - subparsers = parser.add_subparsers(dest='command') - - serialize_parser = subparsers.add_parser( - 'serialize', help="Serialize a model to `--serialized-directory`") - - serialize_parser.add_argument( - "--suffix", - type=str, - required=False, - help=( - "The suffix to append to the serialized model directory, which is " - "used to construct the location of the serialized model tensors, " - "e.g. if `--serialized-directory` is `s3://my-bucket/` and " - "`--suffix` is `v1`, the serialized model tensors will be " - "saved to " - "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " - "If none is provided, a random UUID will be used.")) - serialize_parser.add_argument( - "--serialized-directory", - type=str, - required=True) - - serialize_parser.add_argument( - "--keyfile", - type=str, - required=False, - help=("Encrypt the model weights with a randomly-generated binary key," - " and save the key at this path")) - - deserialize_parser = subparsers.add_parser( - 'deserialize', - help=("Deserialize a model from `--path-to-tensors`" - " to verify it can be loaded and used.")) - - deserialize_parser.add_argument( - "--path-to-tensors", - type=str, - required=True, - help="The local path or S3 URI to the model tensors to deserialize. ") - - deserialize_parser.add_argument( - "--keyfile", - type=str, - required=False, - help=("Path to a binary key to use to decrypt the model weights," - " if the model was serialized with encryption")) - - return parser.parse_args() - - -def make_model_contiguous(model): - # Ensure tensors are saved in memory contiguously - for param in model.parameters(): - param.data = param.data.contiguous() - - -def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. 
" - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def serialize(): - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) - engine = LLMEngine.from_engine_args(engine_args) - - model = (engine.model_executor.driver_worker. - model_runner.model) - - encryption_params = EncryptionParams.random() if keyfile else None - if keyfile: - with _write_stream(keyfile) as stream: - stream.write(encryption_params.key) - - with _write_stream(model_path) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - serializer.close() - - print("Serialization complete. Model tensors saved to", model_path) - if keyfile: - print("Key saved to", keyfile) - - -def deserialize(): - config = AutoConfig.from_pretrained(model_ref) - - with no_init_or_tensor(): - model_class = _get_vllm_model_architecture(config) - model = model_class(config) - - before_mem = get_mem_usage() - start = time.time() - - if keyfile: - with _read_stream(keyfile) as stream: - key = stream.read() - decryption_params = DecryptionParams.from_key(key) - tensorizer_args.deserializer_params['encryption'] = \ - decryption_params - - with (_read_stream(model_path)) as stream, TensorDeserializer( - stream, **tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(model) - end = time.time() - - # Brag about how fast we are. - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - print( - f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" - ) - print(f"Memory usage before: {before_mem}") - print(f"Memory usage after: {after_mem}") - - return model - - -args = parse_args() - -s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") - or None) -s3_secret_access_key = (args.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY") or None) - -s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) - -_read_stream, _write_stream = (partial( - stream_io.open_stream, - mode=mode, - s3_access_key_id=s3_access_key_id, - s3_secret_access_key=s3_secret_access_key, - s3_endpoint=s3_endpoint, -) for mode in ("rb", "wb+")) - -model_ref = args.model - -model_name = model_ref.split("/")[1] - -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "8080" - -init_distributed_environment(world_size=1, rank=0, local_rank=0) -initialize_model_parallel() - -keyfile = args.keyfile if args.keyfile else None - -if args.command == "serialize": - input_dir = args.serialized_directory.rstrip('/') - suffix = args.suffix if args.suffix else uuid.uuid4().hex - base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" - model_path = f"{base_path}/model.tensors" - serialize() -elif args.command == "deserialize": - tensorizer_args = TensorizerArgs.from_cli_args(args) - model_path = args.path_to_tensors - deserialize() -else: - raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index ad4748c5ebe9..1579d53a7fe2 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -10,12 +10,19 @@ import torch from vllm import SamplingParams -from 
vllm.model_executor.model_loader.tensorizer import ( - EncryptionParams, TensorizerConfig, TensorSerializer, - is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) +# yapf: disable +from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + load_with_tensorizer, + open_stream, + serialize_vllm_model) from ..utils import ServerRunner +# yapf conflicts with isort for this docstring + + prompts = [ "Hello, my name is", "The president of the United States is", @@ -40,7 +47,7 @@ def is_curl_installed(): @pytest.fixture(autouse=True) def tensorizer_config(): - config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + config = TensorizerConfig(tensorizer_uri="vllm") return config @@ -59,47 +66,6 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): assert result == mock_agent_instance.deserialize.return_value -def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): - tensorizer_config.vllm_tensorized = True - - result = is_vllm_serialized_tensorizer(tensorizer_config) - - assert result is True - - -def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): - tensorizer_config.vllm_tensorized = False - - result = is_vllm_serialized_tensorizer(tensorizer_config) - - assert result is False - - -def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): - vllm_model = vllm_runner(model_ref) - model_path = tmp_path / (model_ref + ".tensors") - outputs = vllm_model.generate(prompts, sampling_params) - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(model) - del vllm_model, model - gc.collect() - torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner( - model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path, - num_readers=1, - vllm_tensorized=True), - ) - deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) - - # Assumes SamplingParams being seeded ensures the outputs are deterministic - assert outputs == deserialized_outputs - - @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" @@ -110,7 +76,6 @@ def test_can_deserialize_s3(vllm_runner): model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, - vllm_tensorized=False, s3_endpoint="object.ord1.coreweave.com", )) @@ -126,29 +91,26 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( model_path = tmp_path / (model_ref + ".tensors") key_path = tmp_path / (model_ref + ".key") outputs = vllm_model.generate(prompts, sampling_params) - model = (vllm_model.model.llm_engine.model_executor.driver_worker. 
- model_runner.model) - encryption_params = EncryptionParams.random() - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - with open_stream(key_path, "wb+") as stream: - stream.write(encryption_params.key) - del vllm_model, model + config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) + serialize_vllm_model(vllm_model.model.llm_engine, + config_for_serializing, + encryption_key_path=key_path) + + del vllm_model gc.collect() torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri=model_path, - encryption_keyfile=key_path, - num_readers=1, - vllm_tensorized=True)) + + config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, + encryption_keyfile=key_path) + + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=config_for_deserializing) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) - # Assumes SamplingParams being seeded ensures the outputs are deterministic assert outputs == deserialized_outputs @@ -169,7 +131,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - vllm_tensorized=False)) + )) deserialized_outputs = loaded_hf_model.generate_greedy( prompts, max_tokens=max_tokens) @@ -190,12 +152,11 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): # Serialize model before deserializing and binding LoRA adapters vllm_model = vllm_runner(model_ref, ) model_path = tmp_path / (model_ref + ".tensors") - model = (vllm_model.model.llm_engine.model_executor.driver_worker. 
- model_runner.model) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(model) - del vllm_model, model + + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) + + del vllm_model gc.collect() torch.cuda.empty_cache() loaded_vllm_model = vllm_runner( @@ -204,7 +165,6 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - vllm_tensorized=True, ), enable_lora=True, max_loras=1, @@ -220,58 +180,28 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): def test_load_without_tensorizer_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, - model_loader_extra_config=TensorizerConfig( - tensorizer_uri="test", vllm_tensorized=False)) - - -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_tensorize_vllm_model(tmp_path): - # Test serialize command - serialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, - "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - print(result.stdout) # Print the output of the serialize command - - assert result.returncode == 0, (f"Serialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") - - path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" - - # Test deserialize command - deserialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "deserialize", "--path-to-tensors", - path_to_tensors - ] - result = subprocess.run(deserialize_args, capture_output=True, text=True) - assert result.returncode == 0, (f"Deserialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") + vllm_runner( + model_ref, + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_openai_apiserver_with_tensorizer(tmp_path): +def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ## Serialize model - serialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, - "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - print(result.stdout) # Print the output of the serialize command + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") - assert result.returncode == 0, (f"Serialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) - path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" model_loader_extra_config = { - "tensorizer_uri": path_to_tensors, - "vllm_tensorized": True + "tensorizer_uri": str(model_path), } + del vllm_model + gc.collect() + torch.cuda.empty_cache() + ## Start OpenAI API server openai_args = [ "--model", model_ref, "--dtype", "float16", "--load-format", @@ -304,10 +234,10 @@ def test_openai_apiserver_with_tensorizer(tmp_path): def test_raise_value_error_on_invalid_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, - load_format="safetensors", - model_loader_extra_config=TensorizerConfig( - 
tensorizer_uri="test", vllm_tensorized=False)) + vllm_runner( + model_ref, + load_format="safetensors", + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) def test_tensorizer_with_tp(vllm_runner): @@ -321,8 +251,29 @@ def test_tensorizer_with_tp(vllm_runner): model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, - vllm_tensorized=False, s3_endpoint="object.ord1.coreweave.com", ), tensor_parallel_size=2, ) + + +def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path)) + + vllm_model = vllm_runner(model_ref) + outputs = vllm_model.generate(prompts, sampling_params) + serialize_vllm_model(vllm_model.model.llm_engine, config) + + assert is_vllm_tensorized(config) + del vllm_model + gc.collect() + torch.cuda.empty_cache() + + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=config) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + assert outputs == deserialized_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 163723b4be36..fd5338c46c34 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -167,8 +167,8 @@ def add_cli_args( '* "dummy" will initialize the weights with random values, ' 'which is mainly for profiling.\n' '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave which assumes tensorizer_uri is set to the location of ' - 'the serialized weights.') + 'CoreWeave. See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n') parser.add_argument( '--dtype', type=str, diff --git a/vllm/envs.py b/vllm/envs.py index 91cc8f3be775..68d8a074d091 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,7 +145,7 @@ # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": - lambda: os.environ.get("S3_ACCESS_KEY", None), + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), "S3_ENDPOINT_URL": diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fc9c8aa0af44..b14824a359b6 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) @@ -291,7 +291,7 @@ def _get_weights_iterator( tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) - def _load_model_unserialized( + def _load_model_serialized_cpu( self, model_config: ModelConfig, device_config: DeviceConfig, @@ -299,11 +299,12 @@ def _load_model_unserialized( vision_language_config: Optional[VisionLanguageConfig], cache_config: CacheConfig, ) -> nn.Module: - """Load an unserialized model with tensorizer. + """Load a serialized model with tensorizer to the CPU. - Unserialized here means "not serialized with tensorizer". 
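# --- Illustrative sketch (not part of this diff) -----------------------------
# Condensed version of the round-trip exercised by the tests above: serialize a
# live model with serialize_vllm_model(), confirm the marker is detected by
# is_vllm_tensorized(), then reload it through the tensorizer load format.
# Assumes a local /tmp path and that the LLM object exposes its engine as
# .llm_engine, which is how the tests reach it (vllm_model.model.llm_engine).
import gc

import torch

from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         is_vllm_tensorized,
                                                         serialize_vllm_model)

config = TensorizerConfig(tensorizer_uri="/tmp/opt-125m.tensors")  # placeholder path

llm = LLM(model="facebook/opt-125m")
serialize_vllm_model(llm.llm_engine, config)
assert is_vllm_tensorized(config)  # marker written during serialization

# Free GPU memory before reloading, as the tests do.
del llm
gc.collect()
torch.cuda.empty_cache()

loaded = LLM(model="facebook/opt-125m",
             load_format="tensorizer",
             model_loader_extra_config=config)
print(loaded.generate(["Hello, my name is"],
                      SamplingParams(temperature=0.0, max_tokens=32)))
# -----------------------------------------------------------------------------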
This - should still be faster than default HuggingFace loading, but will - be slower than loading a tensorizer-serialized model. + This is only necessary when the model isn't vLLM-tensorized (see + examples/tensorize_vllm_model.py) This should still be faster than + default HuggingFace loading, but will be slower than loading a + vLLM-tensorized model. """ with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): @@ -324,8 +325,9 @@ def _load_model_serialized( ) -> nn.Module: """Load a serialized model with tensorizer. - See the examples/tensorize_vllm_model.py example " - script for serializing vLLM models.""" + Expects a vLLM-tensorized model. See the + examples/tensorize_vllm_model.py example script + for serializing vLLM models.""" with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] @@ -353,15 +355,15 @@ def load_model(self, *, model_config: ModelConfig, cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) - if is_vllm_serialized_tensorizer(self.tensorizer_config): + if is_vllm_tensorized(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, vision_language_config, cache_config) - return self._load_model_unserialized(model_config, device_config, - lora_config, - vision_language_config, - cache_config) + return self._load_model_serialized_cpu(model_config, device_config, + lora_config, + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 219a2a392e12..2cf4ce5f8852 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -5,6 +5,7 @@ import time import typing from dataclasses import dataclass +from functools import partial from typing import Generator, Optional, Tuple, Type, Union import torch @@ -13,6 +14,7 @@ import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig +from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -27,6 +29,11 @@ from tensorizer.stream_io import open_stream from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) except ImportError as e: tensorizer_error_msg = str(e) @@ -43,7 +50,7 @@ class TensorizerConfig: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, str, bytes, os.PathLike, int] - vllm_tensorized: bool + vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None @@ -93,17 +100,11 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig, return tensorizer.deserialize() -def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: - if tensorizer_config is None: - return False - return tensorizer_config.vllm_tensorized - - @dataclass class TensorizerArgs: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, str, bytes, os.PathLike, int] - vllm_tensorized: bool + vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None @@ 
-121,7 +122,9 @@ class TensorizerArgs: vLLM model. This is used to determine the behavior of the TensorDeserializer when loading tensors from a serialized model. It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. verify_hash: If True, the hashes of each tensor will be verified against the hashes stored in the metadata. A `HashMismatchError` will be raised if any of the hashes do not match. @@ -158,6 +161,7 @@ def __post_init__(self): "encryption": self.encryption_keyfile, "num_readers": self.num_readers } + if self.encryption_keyfile: with open_stream( self.encryption_keyfile, @@ -177,7 +181,14 @@ def add_cli_args( 'tensorizer options', description=('Options for configuring the behavior of the' ' tensorizer deserializer when ' - '--load-format=tensorizer')) + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) group.add_argument( "--tensorizer-uri", @@ -222,13 +233,6 @@ def add_cli_args( help="The endpoint for the S3 bucket. Can also be set via the " "S3_ENDPOINT_URL environment variable.", ) - group.add_argument( - "--vllm-tensorized", - action="store_true", - help="If enabled, indicates that the serialized model is a vLLM " - "model. This is used to determine the behavior of the " - "TensorDeserializer when loading tensors from a " - "serialized model.") return parser @@ -322,10 +326,9 @@ def deserialize(self): """ before_mem = get_mem_usage() start = time.perf_counter() - with open_stream( - self.tensorizer_args.tensorizer_uri, - mode="rb", - **self.tensorizer_args.stream_params, + with _read_stream( + self.tensorizer_config.tensorizer_uri, + **self.tensorizer_args.stream_params ) as stream, TensorDeserializer( stream, dtype=self.tensorizer_config.dtype, @@ -345,6 +348,7 @@ def deserialize(self): self._check_tensors_on_meta_device() self._resize_lora_embeddings() + del self.model.vllm_tensorized_marker return self.model.eval() @@ -366,3 +370,63 @@ def tensorizer_weights_iterator( for name, param in state.items(): yield name, param del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. 
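# --- Illustrative sketch (not part of this diff) -----------------------------
# How the serialization helpers added in this file fit together for encrypted
# weights: serialize_vllm_model() writes a randomly generated key to
# encryption_key_path and encrypts the tensors with it; loading later only
# needs encryption_keyfile pointed at that same key. Paths are placeholders,
# and the llm.llm_engine access mirrors the test suite's usage.
from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         serialize_vllm_model)

model_path = "/tmp/opt-125m.tensors"  # placeholder
key_path = "/tmp/opt-125m.key"        # placeholder

llm = LLM(model="facebook/opt-125m")
serialize_vllm_model(llm.llm_engine,
                     TensorizerConfig(tensorizer_uri=model_path),
                     encryption_key_path=key_path)

# Later, supply the key through the config handed to
# LLM(load_format="tensorizer", model_loader_extra_config=...):
config_for_loading = TensorizerConfig(tensorizer_uri=model_path,
                                      encryption_keyfile=key_path)
# -----------------------------------------------------------------------------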
+ """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + if (".vllm_tensorized_marker" in deserializer): + return True + return False + + +def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module: + model = (engine.model_executor.driver_worker.model_runner.model) + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + return model + + +def serialize_vllm_model(engine: "LLMEngine", + tensorizer_config : TensorizerConfig, + encryption_key_path: Optional[str] = None) \ + -> nn.Module: + + model = get_pretensorized_vllm_model(engine) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + encryption_params = None + if encryption_key_path is not None: + encryption_params = EncryptionParams.random() + with _write_stream(encryption_key_path, + **tensorizer_args.stream_params) as stream: + stream.write(encryption_params.key) + + with _write_stream(tensorizer_args.tensorizer_uri, + **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", + str(tensorizer_args.tensorizer_uri)) + return model