[*.py] Resolve pylint w-class: W0102,W0107,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 ; [code_style.sh,.github/workflows/CPUTests.yml] Enable w-class #1749
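For orientation, the snippet below is a hypothetical example (not from this PR) that trips several of the W-class checks named in the title; the message IDs in the comments are pylint's standard codes:

```python
import os  # W0611: unused-import (os is never used below)


def append_item(item, bucket=[]):  # W0102: dangerous-default-value
  bucket.append(item)
  return bucket


def load(path):
  try:
    with open(path) as f:  # W1514: unspecified-encoding (no encoding= argument)
      return f.read()
  except Exception:  # W0718: broad-exception-caught
    raise ValueError(path)  # W0707: raise-missing-from (no exception chaining)
```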
Changes from all commits
```diff
@@ -110,7 +110,7 @@ def apply_mask_to_logits(logits: Array, mask: Array):
 # TODO(agagik): change splash_attention_mask._ComputableMask to be non protected
-class ChunkedCausalMask(splash_attention_mask._ComputableMask):
+class ChunkedCausalMask(splash_attention_mask._ComputableMask):  # pylint: disable=protected-access
   """Lazy chunked causal mask.

   Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens attend to each other but not across chunks.
```
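The inline comment added above suppresses W0212 (protected-access) for that single line only. A minimal standalone sketch of the warning and the line-level disable (names here are illustrative, not from MaxText):

```python
class Cache:

  def __init__(self):
    self._entries = {}  # single leading underscore: protected by convention


cache = Cache()
# Reading a protected member from outside the class trips W0212; the inline
# comment scopes the suppression to this one line, as in the hunk above:
size = len(cache._entries)  # pylint: disable=protected-access
```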
```diff
@@ -633,6 +633,9 @@ def tpu_flash_attention(
     axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q)
     axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv)

+    global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
+    global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
+    global global_q_layout, global_k_layout, global_v_layout
     global_block_q = self.config.sa_block_q
     global_block_kv = self.config.sa_block_kv
     global_block_kv_compute = self.config.sa_block_kv_compute
```

Review discussion on the added `global` declarations:

Reviewer: can you explain the purpose of `global` here and below

Author: It was overriding variables that were set globally in the file; shadow variables. Is it not intended for these to take over the global params? The wording is confusing, so I thought the author just forgot to make them global.
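On the reviewer's question: in Python, an assignment inside a function creates a local binding unless the name is declared `global`, so without these declarations the assignments in this method would silently create locals that shadow the module-level settings. A minimal sketch of the difference (variable names are illustrative, not from MaxText):

```python
block_q = 512  # module-level default, read by other functions in the file


def set_block_q_shadowed(new_size):
  block_q = new_size  # binds a LOCAL that shadows the module variable


def set_block_q(new_size):
  global block_q  # rebinds the module-level name instead
  block_q = new_size


set_block_q_shadowed(1024)
print(block_q)  # still 512: only a throwaway local was assigned
set_block_q(1024)
print(block_q)  # 1024
```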
```diff
@@ -1078,15 +1081,16 @@ def __call__(
       value,
       decoder_segment_ids,
       model_mode,
-      cached_values=[None, None],
+      cached_values=None,
       previous_chunk=None,
       bidirectional_mask=None,
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
   ):
-    prefill_kv_cache = cached_values[0]
-    ar_kv_cache = cached_values[1]
+    if cached_values is None:
+      prefill_kv_cache, ar_kv_cache = None, None
+    else:
+      prefill_kv_cache, ar_kv_cache = cached_values[0], cached_values[1]
     if model_mode != MODEL_MODE_TRAIN:
       assert prefill_kv_cache
       key, value, decoder_segment_ids = prefill_kv_cache
```
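This hunk fixes W0102 (dangerous-default-value): a mutable default like `[None, None]` is built once, at function-definition time, and shared by every call, so any mutation leaks between calls. The `None`-sentinel pattern adopted above avoids that. A standalone illustration:

```python
def record(value, log=[]):  # W0102: the default list is created once and shared
  log.append(value)
  return log


print(record(1))  # [1]
print(record(2))  # [1, 2]  <- state leaked in from the previous call


def record_fixed(value, log=None):  # the None-sentinel pattern used in the diff
  if log is None:
    log = []
  log.append(value)
  return log
```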
```diff
@@ -1841,6 +1845,7 @@ def __call__(
     return out


+# pylint: disable=protected-access
 class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
   """Lazy causal mask, prevents the model from attending to future tokens.
   Attributes:
```
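Note the PR uses two suppression styles for W0212: the inline comment on ChunkedCausalMask covers a single line, whereas the standalone comment added here applies from that line to the end of its enclosing scope, so the class can reference the protected base without per-line comments. A runnable sketch using two protected attributes that really exist on stdlib's argparse.ArgumentParser:

```python
import argparse

parser = argparse.ArgumentParser()

# pylint: disable=protected-access
# A disable comment on a line of its own applies from this point to the end of
# the enclosing scope (here, the rest of the module), so neither protected
# access below needs its own inline comment.
actions = parser._actions
registries = parser._registries
```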