Commit c574207: minor fixes

1 parent: fcbbe16

5 files changed (+21, -28 lines)

models/multiplexing.py

Lines changed: 5 additions & 7 deletions

```diff
@@ -235,8 +235,6 @@ def forward(
             .expand(modified_batch_size, modified_seq_length),
             instance_labels,
         ]
-        retrieval_labels = torch.div(retrieval_labels, self.config.retrieval_loss_vocab_scale, rounding_mode='trunc')
-        retrieval_labels = retrieval_labels.long()
         retrieval_labels[:, :special_tokens_end_position] = -100

         pad_mask = retrieval_labels == 1
@@ -258,7 +256,7 @@ def forward(
         loss_fct = CrossEntropyLoss()
         task_loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
         retrieval_loss = loss_fct(
-            retrieval_predictions.view(-1, self.config.vocab_size)
+            retrieval_predictions.view(-1, self.config.vocab_size),
             retrieval_labels.view(-1),
         )
         loss = (self.task_loss_coeff * task_loss) + (
@@ -582,8 +580,8 @@ def __init__(self, config):
         self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-        self.decoder = nn.Linear(config.hidden_size, math.ceil(config.vocab_size / config.retrieval_loss_vocab_scale))
-        self.bias = nn.Parameter(torch.zeros(math.ceil(config.vocab_size / config.retrieval_loss_vocab_scale)))
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
         self.decoder.bias = self.bias

     def forward(self, features, instance_labels, **kwargs):
@@ -727,8 +725,8 @@ def __init__(self, config):
         self.layer_norm_pre_vocab = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
-        self.decoder = nn.Linear(config.hidden_size, math.ceil(config.vocab_size / config.retrieval_loss_vocab_scale))
-        self.bias = nn.Parameter(torch.zeros(math.ceil(config.vocab_size / config.retrieval_loss_vocab_scale)))
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
         self.decoder.bias = self.bias

     def forward(self, features, instance_labels, **kwargs):
```
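Taken together, these changes drop the `retrieval_loss_vocab_scale` bucketing: retrieval labels are no longer integer-divided into a reduced label space, and the demultiplexing decoder heads widen back to the full `config.vocab_size`. A minimal sketch of the before/after behavior, using illustrative sizes rather than the repository's config:

```python
# Hypothetical sketch, not the repository's code: contrasts the removed
# bucketed retrieval head with the restored full-vocabulary head.
import math
import torch
import torch.nn as nn

vocab_size, hidden_size, scale = 30522, 768, 4  # illustrative values

labels = torch.randint(0, vocab_size, (8, 128))
hidden = torch.randn(8, 128, hidden_size)

# Removed behavior: labels collapsed into ceil(vocab_size / scale) buckets,
# so the head only had to score a shrunken vocabulary.
bucketed_labels = torch.div(labels, scale, rounding_mode="trunc").long()
bucketed_head = nn.Linear(hidden_size, math.ceil(vocab_size / scale))

# Restored behavior: the head scores every token id directly.
full_head = nn.Linear(hidden_size, vocab_size)
retrieval_loss = nn.CrossEntropyLoss()(
    full_head(hidden).view(-1, vocab_size), labels.view(-1)
)
```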

models/trainer.py

Lines changed: 11 additions & 9 deletions

```diff
@@ -1754,12 +1754,13 @@ def prediction_step(
                 ) = self.compute_loss(model, inputs, return_outputs=True)
                 loss = loss.mean().detach()
                 if isinstance(outputs, dict):
-                    logits = tuple(
-                        v
-                        for k, v in outputs.items()
-                        if k
-                        not in ignore_keys + ["loss", "task_loss", "retrieval_loss"]
-                    )
+                    # logits = tuple(
+                    #     v
+                    #     for k, v in outputs.items()
+                    #     if k
+                    #     not in ignore_keys + ["loss", "task_loss", "retrieval_loss"]
+                    # )
+                    logits = outputs["logits"] if "logits" in outputs else None
                 else:
                     logits = outputs[1:]
                 if "retrieval_loss" in outputs:
@@ -1778,9 +1779,10 @@ def prediction_step(
             else:
                 outputs = model(**inputs)
                 if isinstance(outputs, dict):
-                    logits = tuple(
-                        v for k, v in outputs.items() if k not in ignore_keys
-                    )
+                    # logits = tuple(
+                    #     v for k, v in outputs.items() if k not in ignore_keys
+                    # )
+                    logits = outputs["logits"] if "logits" in outputs else None
                 else:
                     logits = outputs
             if self.args.past_index >= 0:
```
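In both branches the commented-out code built a tuple of every output value not in `ignore_keys`, which mixed auxiliary tensors in with the predictions; the replacement keeps only the `logits` entry. A small sketch of the new selection logic, assuming dict-style model outputs as in the diff:

```python
# Minimal sketch of the updated logits extraction (names mirror the diff;
# ignore_keys handling is dropped because only "logits" is kept now).
def extract_logits(outputs):
    if isinstance(outputs, dict):
        # New behavior: take just the logits, or None when the model
        # returned no "logits" entry.
        return outputs["logits"] if "logits" in outputs else None
    # Tuple-style outputs keep the old convention: everything after the loss.
    return outputs[1:]

outputs = {"loss": 0.31, "logits": [[0.1, 0.9]], "retrieval_loss": 0.05}
assert extract_logits(outputs) == [[0.1, 0.9]]
```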

run_glue.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -696,7 +696,7 @@ def compute_metrics(p: EvalPrediction):
     # eval_datasets.append(datasets["validation_mismatched"])

     for eval_dataset, task in zip(eval_datasets, tasks):
-        metrics = trainer.evaluate(eval_dataset=eval_dataset)
+        metrics = trainer.evaluate(eval_dataset=eval_dataset)

         max_eval_samples = (
             data_args.max_eval_samples
```

The removed and added lines are textually identical here, so the change is most likely whitespace-only (indentation), which the page extraction does not preserve.

run_ner.py

Lines changed: 0 additions & 9 deletions

```diff
@@ -291,11 +291,6 @@ def main():
     last_checkpoint = None
     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
-        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-            raise ValueError(
-                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
-                "Use --overwrite_output_dir to overcome."
-            )

     # Set seed before initializing model.
     set_seed(training_args.seed)
@@ -643,10 +638,6 @@ def compute_metrics(p):
     else:
         kwargs["dataset"] = data_args.dataset_name

-    if training_args.push_to_hub:
-        trainer.push_to_hub(**kwargs)
-    else:
-        trainer.create_model_card(**kwargs)


 def _mp_fn(index):
```
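Both deletions loosen the example script: the first lets training proceed even when `output_dir` exists and holds non-checkpoint files (previously a hard error), and the second drops the automatic `push_to_hub`/model-card step. A hedged sketch of the checkpoint-resume logic that remains, using the real `get_last_checkpoint` helper from `transformers`:

```python
# Sketch of the resume logic left after this commit; the raise that guarded
# a non-empty, checkpoint-free output_dir is gone, so such a directory is
# now silently reused.
import os
from transformers.trainer_utils import get_last_checkpoint

def resolve_checkpoint(output_dir, do_train, overwrite_output_dir):
    last_checkpoint = None
    if os.path.isdir(output_dir) and do_train and not overwrite_output_dir:
        # Returns None when no checkpoint-* subdirectory exists.
        last_checkpoint = get_last_checkpoint(output_dir)
    return last_checkpoint
```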

run_ner.sh

Lines changed: 4 additions & 2 deletions

```diff
@@ -277,8 +277,10 @@ CMD="python run_ner.py \
     --demuxing_variant ${DEMUXING} \
     --should_mux ${SHOULD_MUX} \
     --gaussian_hadamard_norm ${RANDOM_ENCODING_NORM} \
-    --learn_muxing ${LEARN_MUXING}"
-
+    --learn_muxing ${LEARN_MUXING} \
+    --load_best_model_at_end 1 \
+    --metric_for_best_model eval_f1 \
+    --save_total_limit 1"
 if [ "$DO_TRAIN" -eq 1 ]; then
     CMD="${CMD} --do_train"
 fi
```
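The three new flags are standard HuggingFace `Trainer` arguments: keep the checkpoint that scores best on `eval_f1`, reload it when training finishes, and retain at most one saved checkpoint on disk. An equivalent programmatic form, assuming the script forwards these flags to `TrainingArguments`:

```python
# Hedged sketch: the shell flags above expressed as TrainingArguments.
# Note that load_best_model_at_end requires the evaluation and save
# strategies to match (e.g. both "epoch" or both "steps").
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                 # placeholder path
    load_best_model_at_end=True,      # reload the best checkpoint at the end
    metric_for_best_model="eval_f1",  # rank checkpoints by eval F1
    save_total_limit=1,               # keep at most one checkpoint on disk
)
```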
