Skip to content

add CI to disallow syntax errors and undefined vars in all Python files #861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ruff_linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
- name: Analyzing the code with ruff
run: |
ruff check .
- name: Check all Python files for syntax errors (E999) and undefined vars (F821)
run: |
ruff check --isolated --select E999,F821
- name: Check well formatted code
run: |
ruff format --check
6 changes: 4 additions & 2 deletions benchmarks/benchmark_gpu_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def run_gpu_sparse_benchmark(m, k, n, args):
elif args.eval_fn == "mm":
dense_output = torch.mm(A, x.t())
sparse_output = torch.mm(A_sparse, x.t())
dense_time = benchmark_in_us(torch.mm, A, x.t())
sparse_time = benchmark_in_us(torch.mm, A_sparse, x.t())
# dense_time = benchmark_in_us(torch.mm, A, x.t())
# sparse_time = benchmark_in_us(torch.mm, A_sparse, x.t())
# TODO(future PR) fixme
dense_time, sparse_time = 1.0, 1.0
else:
raise ValueError(f"Unknown eval_fn: {args.eval_fn}")

Expand Down
2 changes: 2 additions & 0 deletions benchmarks/float8/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def run(
scale_b = torch.tensor([1.0], device=device)

def do_matmul(A, B):
nonlocal scale_a
nonlocal scale_b
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)
Expand Down
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):


def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"need input_tensor shape: {input_tensor.shape} final"
f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch.multiprocessing as mp
from ax.modelbridge.cross_validation import cross_validate
from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config, load_parameters_from_json, load_initial_samples
from BO_acc_throughput import define_parameter_list

# return evaluation results to complete BO trials
def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
_load_model,
)

from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config, load_parameters_from_json
from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config, load_parameters_from_json, load_initial_samples

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

Expand Down Expand Up @@ -380,6 +380,8 @@ def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, nu
parameters_list = load_parameters_from_json(args.parameters_list)

# sample initial points
# TODO(future PR): fix me
initial_samples = []
initial_points_set = load_initial_samples(initial_samples)
num_BO_initial_samples = len(initial_points_set)

Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase):
@staticmethod
def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
if dtype is None:
dtype = qscales.dtype
dtype = q_scales.dtype
kwargs["dtype"] = dtype
return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined]

Expand Down
4 changes: 2 additions & 2 deletions torchao/sparsity/prototype/superblock/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchao.sparsity import sparsify_, semi_sparse_weight
from torchao.sparsity.prototype.superblock.supermask import apply_supermask
from torchao.sparsity.prototype.superblock.utils import apply_sparsity, verify_sparsity, mlp_only_with_args
from torchao.sparsity.prototype.superblock.utils import apply_sparsity, verify_sparsity, mlp_only_with_args, simulate_sparsity, accelerate_with_sparsity
from torchao.sparsity.prototype.superblock.train import evaluate, _get_cache_path, load_data
from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier

Expand Down Expand Up @@ -56,7 +56,7 @@ def main(args):
model.to(device).bfloat16()

if sparsifier_or_none is not None:
sparsifier.squash_mask()
sparsifier_or_none.squash_mask()
accelerate_with_sparsity(model, args)

criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
Expand Down
1 change: 1 addition & 0 deletions torchao/sparsity/prototype/superblock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
import torch.distributed as dist

from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight
from torchao.sparsity import sparsify_, semi_sparse_weight
from torchao.sparsity.prototype.superblock.supermask import SupermaskLinear, apply_supermask
from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
Expand Down
2 changes: 1 addition & 1 deletion tutorials/developer_api_guide/my_dtype_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
LayoutType,
PlainLayoutType,
)
from torchao.utils import TorchAOBaseTensor
from torchao.utils import TorchAOBaseTensor, _register_layout_cls, _get_layout_tensor_constructor

aten = torch.ops.aten

Expand Down
Loading