from typing import Dict, Optional, Tuple

+import jax
import jax.numpy as jnp
+import optax
from flax import jax_utils
from algoperf import param_utils
+from algoperf import sharding_utils
from algoperf import spec
from algoperf.workloads.lm.workload import BaseLmWorkload
from algoperf.workloads.lm.lm_jax.models import LinearModel
+from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset


class LmWorkload(BaseLmWorkload):
  """LM JAX workload."""
+  def _build_input_queue(self,
+                         data_rng: jax.random.PRNGKey,
+                         split: str,
+                         data_dir: str,
+                         global_batch_size: int,
+                         num_batches: Optional[int] = None,
+                         repeat_final_dataset: bool = False):
+    """Build an input queue using the pre-cached FineWeb dataset."""
+    del num_batches
+    del repeat_final_dataset
+    loader = get_lm_dataset(
+        data_rng=data_rng,
+        split=split,
+        data_dir=data_dir,
+        global_batch_size=global_batch_size)
+    return loader

  def init_model_fn(
      self,
@@ -21,14 +41,15 @@ def init_model_fn(

    model = LinearModel(vocab_size=self._vocab_size)
    input_shape = (1, self._seq_len, self._vocab_size)
-    variables = model.init(rng, jnp.ones(input_shape, jnp.float32))
-    model_state, params = variables.pop('params')
-
+    params_rng, init_rng = jax.random.split(rng)
+    variables = jax.jit(model.init)({'params': params_rng},
+                                    jnp.ones(input_shape, jnp.float32))
+    params = variables['params']
    self._param_shapes = param_utils.jax_param_shapes(params)
    self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
-
+    params = sharding_utils.shard_replicated(params)
+    model_state = None
+    self._model = model
    return params, model_state

  def model_fn(
@@ -40,15 +61,40 @@ def model_fn(
      rng: spec.RandomState,
      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:

-    del mode, rng, update_batch_norm  # Not used for linear model
-    inputs = batch['inputs']
-    logits = self._model.apply({'params': params, **model_state}, inputs)
-    return logits, model_state
+    del mode, rng, update_batch_norm, model_state
+    inputs = jax.nn.one_hot(batch['inputs'], self._vocab_size, axis=-1)
+    logits = self._model.apply({'params': params}, inputs)
+    return logits, None
+
+  def loss_fn(
+      self,
+      label_batch: spec.Tensor,  # Dense integer labels, not one-hot.
+      logits_batch: spec.Tensor,  # Dense logits.
+      mask_batch: Optional[spec.Tensor] = None,
+      label_smoothing: Optional[float] = 0.0) -> Dict[str, spec.Tensor]:
+    del mask_batch, label_smoothing
+    logits_flat = logits_batch.reshape(-1, self._vocab_size)
+    targets = jax.nn.one_hot(label_batch, self._vocab_size, axis=-1)
+    targets_flat = targets.reshape(-1, self._vocab_size)
+    # Cross-entropy loss, summed over all tokens.
+    loss = -jnp.sum(targets_flat * jax.nn.log_softmax(logits_flat, axis=-1))
+    n_valid_examples = logits_flat.shape[0]
+    return {'summed': loss, 'n_valid_examples': n_valid_examples}

+  def is_output_params(self, param_name: str) -> bool:
+    """Return whether the given parameter is an output parameter."""
+    return 'output' in param_name
+
  def _eval_batch(self,
                  params: spec.ParameterContainer,
                  batch: Dict[str, spec.Tensor],
                  model_state: spec.ModelAuxiliaryState,
                  rng: spec.RandomState) -> spec.Tensor:
    """Evaluate the model on a single batch."""
-    pass
+    logits, _ = self.model_fn(
+        params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
+    targets = jax.nn.one_hot(batch['targets'], self._vocab_size, axis=-1)
+
+    # Cross-entropy loss, summed over the batch.
+    loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1))
+    return loss
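
Note: the summed cross-entropy used in the new loss_fn can be cross-checked against optax, which this change already imports. The snippet below is a minimal, self-contained sketch under assumed shapes (a batch of 4 sequences of length 16 with a hypothetical vocab size of 8); it is illustrative only and not part of the workload.

import jax
import jax.numpy as jnp
import optax

vocab_size = 8
rng = jax.random.PRNGKey(0)
logits = jax.random.normal(rng, (4, 16, vocab_size))      # (batch, seq, vocab)
labels = jax.random.randint(rng, (4, 16), 0, vocab_size)  # integer token ids

# Manual form used in loss_fn: one-hot targets times log-softmax, summed.
logits_flat = logits.reshape(-1, vocab_size)
targets_flat = jax.nn.one_hot(labels, vocab_size, axis=-1).reshape(-1, vocab_size)
summed = -jnp.sum(targets_flat * jax.nn.log_softmax(logits_flat, axis=-1))

# Reference: optax's integer-label cross-entropy, summed over all tokens.
reference = optax.softmax_cross_entropy_with_integer_labels(
    logits_flat, labels.reshape(-1)).sum()

assert jnp.allclose(summed, reference, atol=1e-4)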
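
A quick sanity sketch of the forward path model_fn now takes (integer token ids are one-hot encoded before the apply call). ToyLinearModel below is a hypothetical stand-in for algoperf's LinearModel, whose definition is not part of this diff.

import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyLinearModel(nn.Module):
  """Hypothetical stand-in: a single dense layer over one-hot inputs."""
  vocab_size: int

  @nn.compact
  def __call__(self, x):
    # x: (batch, seq, vocab) one-hot inputs -> (batch, seq, vocab) logits.
    return nn.Dense(self.vocab_size)(x)

vocab_size = 8
model = ToyLinearModel(vocab_size=vocab_size)
token_ids = jnp.zeros((2, 16), jnp.int32)
inputs = jax.nn.one_hot(token_ids, vocab_size, axis=-1)
variables = jax.jit(model.init)({'params': jax.random.PRNGKey(0)}, inputs)
logits = model.apply({'params': variables['params']}, inputs)
assert logits.shape == (2, 16, vocab_size)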