
[*.py] Resolve pylint w-class: W0102,W0107,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 ; [code_style.sh,.github/workflows/CPUTests.yml] Enable w-class #1749

Merged (1 commit, May 20, 2025)

2 changes: 1 addition & 1 deletion .github/workflows/CPUTests.yml
@@ -36,7 +36,7 @@ jobs:
pytype --jobs auto --disable 'import-error,late-directive,wrong-arg-types,module-attr,unsupported-operands' MaxText/ || true
- name: Analysing the code with pylint in Maxtext/
run: |
pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable R0401,R0917,W0102,W0107,W0201,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable R0401,R0917,W0201,W0613 MaxText/ && \
echo 'Maxtext PyLint check successful' || { echo \
'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
- name: Analysing the code with pylint in pedagogical_examples/
1 change: 0 additions & 1 deletion MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -69,7 +69,6 @@ def fmt_size(num_bytes: int) -> str:
return f"{num_bytes:.2f} {unit}"



def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name):
"""convert ckpt."""

7 changes: 4 additions & 3 deletions MaxText/deepseek_fp8_to_bf16.py
@@ -28,10 +28,11 @@
import json
from argparse import ArgumentParser
from glob import glob
import string

from tqdm import tqdm

import torch

from safetensors.torch import load_file, save_file


@@ -93,7 +94,7 @@ def convert_fp8_to_bf16(fp8_path: str, bf16_path: str, cache_file_num: int = 2):
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
with open(model_index_file, "rt", encoding="utf8") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]

@@ -159,7 +160,7 @@ def get_tensor(tensor_name):
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
with open(new_model_index_file, "w") as f:
with open(new_model_index_file, "wt", encoding="utf8") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


2 changes: 1 addition & 1 deletion MaxText/generate_distillation_data.py
@@ -225,7 +225,7 @@ def upload_data(config, data, batch_num): # pylint: disable=redefined-outer-name
if config.remove_local_dataset_files and os.path.exists(parquet_file_name):
try:
os.remove(parquet_file_name)
except Exception as e:
except OSError as e:
max_logging.log(f"Unable to remove local dataset file {parquet_file_name}: {e}")


6 changes: 3 additions & 3 deletions MaxText/inference/paged_attention.py
@@ -272,7 +272,7 @@ def update_prefill_step_pages(
f"getting {key_pages_var.shape=} and {value_pages_var.shape=} instead"
)

v_n_kv, v_n_p, v_p, v_d = key_pages_var.value.shape
v_n_kv, _, v_p, v_d = key_pages_var.value.shape
assert v_n_kv == n_kv_head, f"{v_n_kv=} {n_kv_head=}"
assert v_p == self.tokens_per_page, f"{v_p=} {self.tokens_per_page=}"
assert v_d == head_dim, f"{v_d=} {head_dim=}"
@@ -300,8 +300,8 @@ def update_decode_step_pages(self, key_pages_var, value_pages_var, key, value, p
key_pages = key_pages_var.value
value_pages = value_pages_var.value

batch_size, seq_len, kv_heads, head_dim = key.shape
kv_heads, num_pages, tokens_per_page, head_dim = key_pages.shape
batch_size, _, kv_heads, head_dim = key.shape
kv_heads, _, _, head_dim = key_pages.shape

new_key = key.reshape(batch_size, kv_heads, head_dim)[:, :, :]
new_key = jnp.transpose(new_key, (1, 0, 2)) # [n_kv_heads, batch_size, head_dim]
17 changes: 5 additions & 12 deletions MaxText/inference_mlperf/matmul/matmul_sharding.py
@@ -20,7 +20,6 @@
from jax.experimental import mesh_utils
from jax.sharding import Mesh

from MaxText.inference_mlperf.matmul.timing_util import simple_timeit

PREFILL_LENS = [128, 256, 512, 1024]
EMBED = 8192
@@ -55,16 +54,6 @@ def f(_A, _weights):
_A = jax.lax.with_sharding_constraint(_A @ _weights, activation_sharding)
return _A

num_bits = 32
if dtype == jax.numpy.bfloat16:
num_bits = 16
elif dtype == jax.numpy.int8:
num_bits = 8
elif dtype == jax.numpy.int4:
num_bits = 4

time = simple_timeit(f, A_, W1_, task=f"matmuls_{mesh_dim}_batch_{batch}_bits_{num_bits} ")


def matmuls(mesh, mesh_dim, enable_visual=False):
for dtype in DTYPES:
@@ -74,7 +63,7 @@ def matmuls(mesh, mesh_dim, enable_visual=False):


# Start here
if __name__ == "__main__":
def main():
devices = jax.devices()
print("Devices:")
print(devices)
@@ -91,3 +80,7 @@ def matmuls(mesh, mesh_dim, enable_visual=False):
print("Optimized device topology for 2x4")
print(new_devices)
matmuls(mesh, mesh_dim)


if __name__ == "__main__":
main()
12 changes: 6 additions & 6 deletions MaxText/inference_mlperf/offline_inference.py
@@ -73,18 +73,18 @@ def run(self):
class PrefillHelper:
"""Helper class to manage prefill related code and provide a unified interface."""

def __init__(self, type: str, engine: MaxEngine):
self._type = type
def __init__(self, kind: str, engine: MaxEngine):
self._type = kind
self.engine = engine
if type == "default":
if self._type == "default":
self._processor = PrefillProcessor(engine)
elif type == "batch":
elif self._type == "batch":
self._batch_processor = BatchedPrefillProcessor(engine=engine, max_batch_size=16)
self._processor = PrefillProcessor(engine) # for fallback
elif type == "dummy":
elif self._type == "dummy":
pass
else:
raise ValueError(f"Invalid type: {type}")
raise ValueError(f"Invalid type: {self._type}")

def aot_compile(
self,
7 changes: 3 additions & 4 deletions MaxText/inference_mlperf/offline_mode.py
@@ -362,8 +362,8 @@ def flush_queries(self):
resp = make_response(key, val)
lg.QuerySamplesComplete([resp])

log.info("Flush queries end")
end = time.perf_counter()
log.info("Flush queries end-start: %d", end - start)
gc.collect()

def LoadSamplesToRam(self, sample_list):
@@ -426,12 +426,11 @@ def _estimated_counts_by_bucket(dataset):

def main(argv):
del argv
args = FLAGS
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
# jax.config.update("jax_explain_cache_misses", True)

if FLAGS.enable_profile:
server = jax.profiler.start_server(FLAGS.jax_profiler_port)
jax.profiler.start_server(FLAGS.jax_profiler_port)

settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Offline
@@ -454,7 +453,6 @@ def main(argv):
estimated_counts_by_bucket = _estimated_counts_by_bucket(dataset)
log.info("Dataset len %d, estimated counts by bucket %s", len(dataset), estimated_counts_by_bucket)

rows = list(dataset.iterrows())
len_batch_str = FLAGS.prefill_lengths_and_per_device_batch_sizes
log.info("Prefill lengths and Batch sizes: %s", len_batch_str)
log.info("Maxengine args: %s", FLAGS.maxengine_args)
@@ -524,6 +522,7 @@
)
log.info("Starting Benchmark run")
lg.StartTestWithLogSettings(lgSUT, qsl, settings, log_settings, FLAGS.audit_conf)
# pylint: disable=protected-access
log.info("query counts %s", str(list(map(len, sut._query_batches.values()))))
log.info("Run Completed!")
log.info("Destroying SUT...")
15 changes: 10 additions & 5 deletions MaxText/layers/attentions.py
@@ -110,7 +110,7 @@ def apply_mask_to_logits(logits: Array, mask: Array):


# TODO(agagik): change splash_attention_mask._ComputableMask to be non protected
class ChunkedCausalMask(splash_attention_mask._ComputableMask):
class ChunkedCausalMask(splash_attention_mask._ComputableMask): # pylint: disable=protected-access
"""Lazy chunked causal mask.

Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens attend to each other but not across chunks.
@@ -633,6 +633,9 @@ def tpu_flash_attention(
axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q)
axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv)

global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
shralex (Collaborator) commented on May 16, 2025:
can you explain the purpose of global here and below

Author (Collaborator) replied:
It was overriding variables that were set globally in the file; shadow variables.

Is it not intended for these to take over the global params? The wording is confusing, so I thought the author just forgot to make them global. I can modify to rename all these variables instead if that's not desired.
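
For illustration, a minimal sketch of the shadowing behavior discussed in this thread (hypothetical names, not code from the PR):

block_q = 512  # module-level setting

def set_block_q_shadowed(value):
  # Without a global declaration, this assignment binds a new local name
  # that shadows the module-level one; the module value never changes.
  block_q = value
  return block_q

def set_block_q(value):
  # With the declaration, the assignment rebinds the module-level name.
  global block_q
  block_q = value
  return block_q

set_block_q_shadowed(128)
assert block_q == 512  # shadow: module-level value untouched
set_block_q(128)
assert block_q == 128  # global: module-level value updated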

global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
global global_q_layout, global_k_layout, global_v_layout
global_block_q = self.config.sa_block_q
global_block_kv = self.config.sa_block_kv
global_block_kv_compute = self.config.sa_block_kv_compute
@@ -1078,15 +1081,16 @@ def __call__(
value,
decoder_segment_ids,
model_mode,
cached_values=[None, None],
cached_values=None,
previous_chunk=None,
bidirectional_mask=None,
slot: Optional[int] = None,
page_state: Optional[page_manager.PageState] = None,
):

prefill_kv_cache = cached_values[0]
ar_kv_cache = cached_values[1]
if cached_values is None:
prefill_kv_cache, ar_kv_cache = None, None
else:
prefill_kv_cache, ar_kv_cache = cached_values[0], cached_values[1]
if model_mode != MODEL_MODE_TRAIN:
assert prefill_kv_cache
key, value, decoder_segment_ids = prefill_kv_cache
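
The cached_values=[None, None] to cached_values=None change above is the usual fix for W0102 (dangerous-default-value). A minimal sketch, with hypothetical names, of why a mutable default argument is risky:

def append_token(token, cache=[]):  # W0102: the default list is created once
  cache.append(token)
  return cache

assert append_token("a") == ["a"]
assert append_token("b") == ["a", "b"]  # state leaks across calls

def append_token_fixed(token, cache=None):  # the None-sentinel pattern used here
  if cache is None:
    cache = []
  cache.append(token)
  return cache

assert append_token_fixed("a") == ["a"]
assert append_token_fixed("b") == ["b"]  # fresh list per call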
@@ -1841,6 +1845,7 @@ def __call__(
return out


# pylint: disable=protected-access
class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
"""Lazy causal mask, prevents the model from attending to future tokens.
Attributes:
1 change: 1 addition & 0 deletions MaxText/layers/embeddings.py
@@ -73,6 +73,7 @@ def setup(self):
# Move embeddings to device if parameter offloading is enabled
if self.config.parameter_memory_host_offload:
max_logging.log("embeddings.py: Moving embedding parameter to device")
# pylint: disable=protected-access
self.embedding = jax.device_put(embedding, jax._src.sharding_impls.TransferToMemoryKind("device"))
else:
self.embedding = embedding
8 changes: 5 additions & 3 deletions MaxText/layers/models.py
@@ -291,7 +291,9 @@ def move_to_device(variables):

def map_fn(path, value):
max_logging.log(f"models.py: Moving parameter {path} to device")
return jax.device_put(value, jax._src.sharding_impls.TransferToMemoryKind("device"))
return jax.device_put(
value, jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access
)

return jax.tree_util.tree_map_with_path(map_fn, variables)

@@ -513,7 +515,7 @@ def __call__(
else:
if cfg.scan_layers:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, f"Scanned layers must have a length of 2 using deepseek."
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]
y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)(
@@ -553,7 +555,7 @@
)
else:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, f"Unscanned layers must have a length of 2 using deepseek."
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]

3 changes: 3 additions & 0 deletions MaxText/layers/moe.py
@@ -547,6 +547,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo):
layer_w0 = checkpoint_name(layer_w0, "mlpwi_0")
layer_w1 = gmm(x, w1, group_sizes)
layer_w1 = checkpoint_name(layer_w1, "mlpwi_1")
# pylint: disable=protected-access
layer_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
intermediate_output = gmm(intermediate_layer, wo, group_sizes)
@@ -881,6 +882,7 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne
mlp_axis,
)
layer_w1 = checkpoint_name(layer_w1, "mlpwi_1")
# pylint: disable=protected-access
layer_w0_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype)
with jax.named_scope("wo"):
@@ -924,6 +926,7 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne
if self.config.activations_in_float32:
layer_w1 = layer_w1.astype(jnp.float32)
layer_w1 = checkpoint_name(layer_w1, "mlpwi_1")
# pylint: disable=protected-access
layer_w0_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype)
with jax.named_scope("wo"):
2 changes: 1 addition & 1 deletion MaxText/layers/pipeline.py
@@ -769,7 +769,7 @@ def run_iteration_scannable(model, loop_state, xs):
)
loop_state, _ = run_all_iterations_scanned(self, loop_state, None)
else:
for loop_iteration in range(total_iterations):
for _ in range(total_iterations):
loop_state, _ = run_iteration_scannable(self, loop_state, None)

# The final output is located in the input/output array, however the output microbatches may be permuted relative to
3 changes: 3 additions & 0 deletions MaxText/layers/quantizations.py
@@ -123,6 +123,7 @@ def _get_mixed_precision_cfg(self):
quant_dg = None
is_tiled = False
tiling_fn = None
# pylint: disable=protected-access
module_path = "/".join(nn.module._context.module_stack[-1].path)
tile_size = -1
for layer_name_re, layer_quant_dg in self.quant_dg.items():
@@ -410,6 +411,7 @@ def match_aqt_and_unquantized_param(aqt_params, params):
param_paths = []

for aqt_k, _ in aqt_param_flat:
index = None
for index, (k, _) in enumerate(param_tree_flat):
path_depth = len(k)
# every quantized parameter has AQT.. as the leaf node
@@ -419,6 +421,7 @@
aqt_paths.append(aqt_k)
param_paths.append(k)
break
assert index is not None
# since the parameter is already added, we can delete it.
param_tree_flat.pop(index)
return jax.tree_util.tree_unflatten(aqt_tree_def, param_paths)
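
The index = None initialization and the assert after the loop above are the W0631 (undefined-loop-variable) fix: index is read after the loop, and without the guard it would be unbound if the flattened parameter list were empty. A minimal sketch of the same pattern, with hypothetical data:

def pop_first_match(items, target):
  index = None  # defined even if the loop body never runs
  for index, item in enumerate(items):
    if item == target:
      break
  assert index is not None, "nothing to search"
  return items.pop(index)

data = ["w0", "w1", "wo"]
assert pop_first_match(data, "w1") == "w1"
assert data == ["w0", "wo"]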
5 changes: 1 addition & 4 deletions MaxText/llama4_ckpt_unscanned.py
@@ -52,7 +52,7 @@

from MaxText import max_logging
from MaxText.inference_utils import str2bool
from MaxText.llama_or_mistral_ckpt import save_weights_to_checkpoint, permute_to_match_maxtext_rope, MODEL_PARAMS_DICT
from MaxText.llama_or_mistral_ckpt import save_weights_to_checkpoint, MODEL_PARAMS_DICT

SIMULATED_CPU_DEVICES_COUNT = 16

@@ -264,7 +264,6 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m

# self attention ###############################################
max_logging.log("Processing self attention")
has_printed_warning = False
for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
layer_name = f"layers_{layer_idx}"
jax_weights["decoder"].update(
@@ -513,7 +512,6 @@ def _convert_pytorch_to_jax_weights(base_model_path: str, model_size: str, model

# llama3.1-405b kv weight is replicated within every two files.
wkv_step = 1 if model_size != "llama3.1-405b" else 2
has_printed_warning = False

for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
layer_name = f"layers_{layer_idx}"
@@ -727,7 +725,6 @@ def _convert_pytorch_to_jax_weights(base_model_path: str, model_size: str, model
axis=2,
)
# NOTE: should probably update this to be more rigorous, but this should be fine for now
f_dim = wi_0.shape[-1]
wi_1 = np.concatenate(
[
var[f"layers.{layer_idx}.feed_forward.experts.moe_w_swiglu_eD_F"]
2 changes: 1 addition & 1 deletion MaxText/multimodal_utils.py
@@ -52,7 +52,7 @@ def load_image_from_path(image_path):
image.load() # Load image data to catch errors early
return jnp.asarray(np.array(image))
except (IOError, OSError) as e:
raise IOError(f"Error loading image from {image_path}")
raise IOError(f"Error loading image from {image_path}") from e


def _normalize_images(images, mean, std):
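The from e added above is the W0707 (raise-missing-from) fix: explicit chaining records the caught exception as the direct cause, instead of reporting "During handling of the above exception, another exception occurred". A minimal sketch of the same pattern (hypothetical function):

def load_text(path):
  try:
    with open(path, "rt", encoding="utf8") as f:
      return f.read()
  except OSError as e:
    # "from e" preserves the underlying OSError as __cause__ in the traceback
    raise IOError(f"Error loading {path}") from e
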
4 changes: 3 additions & 1 deletion MaxText/scratch_code/generate_grpo_golden_logits.py
@@ -162,6 +162,7 @@ def test_w_trl_and_write_golden_data(self):
# using the same model as the ref model,
# which is equivalent of step 0 of GRPO training when
# the on-policy params are the same as the ref model
# pylint: disable=protected-access
"ref_per_token_logps": self.trainer._get_per_token_logps(
self.hf_model, hf_input_ids, attention_mask, logits_to_keep
), # pylint: disable=protected-access
@@ -170,7 +171,7 @@
}
hf_loss = self.trainer.compute_loss(self.hf_model, inputs)

hf_per_token_logps = self.trainer._get_per_token_logps(self.hf_model, hf_input_ids, attention_mask, logits_to_keep) # pylint: disable=protected-access
self.trainer._get_per_token_logps(self.hf_model, hf_input_ids, attention_mask, logits_to_keep) # pylint: disable=protected-access

input_ids, input_segmentation, input_position, completion_segmentation = prepare_maxtext_inputs(
self.cfg.prompt, self.tokenizer_model
@@ -199,6 +200,7 @@ def test_w_trl_and_write_golden_data(self):
"ar_completions_segmentation": completion_segmentation,
}
maxtext_loss, aux = grpo_loss_fn(self.model, self.cfg, data, self.rng, self.state.params, reference_params)
# pylint: disable=protected-access
self.assertEqual(self.trainer._metrics["train"]["kl"][0], aux.avg_kl.tolist())
self.assertEqual(hf_loss.item(), maxtext_loss.tolist())
# since this is on-policy