Skip to content

Add support for save quantized checkpoint in llama code #553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions scripts/hf_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def format_value(value):

print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

if compile:
if quantization == "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

if quantization == "int8dq":
Expand All @@ -57,6 +57,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
quantize_(model.to(device=device), int4_weight_only())
elif quantization == "autoquant":
model = autoquant(model.to(device=device))

if quantization != "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.no_grad():
result = evaluate(
HFLM(
Expand All @@ -70,6 +74,12 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi

pretty_print_nested_results(result)

if save:
# This doesn't work yet: https://github.com/huggingface/transformers/issues/32364
# model.save_pretrained("quantized_model_test", safe_serialization=False)
file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt"
torch.save(model.state_dict(), file_name)


if __name__ == '__main__':
import argparse
Expand All @@ -81,8 +91,9 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--save', action='store_true', help='Whether to save the model.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

args = parser.parse_args()
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
10 changes: 9 additions & 1 deletion torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
from pathlib import Path
Expand Down Expand Up @@ -165,6 +166,7 @@ def main(
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
kv_cache_quantization: bool = False,
save: bool = False,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
Expand Down Expand Up @@ -238,6 +240,11 @@ def main(

model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

if save:
output_dir = str(checkpoint_path.cwd())
filename = str(checkpoint_path.name).split(".")[0]
torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt"))

if compile:
print("Compiling Model")
global decode_one_token, prefill
Expand Down Expand Up @@ -362,6 +369,7 @@ def callback(x):
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
Expand All @@ -372,5 +380,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
)
Loading