Skip to content

Commit 2d8ab1f

Browse files
committed
janky seq2seq support (will be reverted)
1 parent c0a1569 commit 2d8ab1f

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

bigcode_eval/tasks/shadereval.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def __init__(self, prompt="minimal"):
133133
requires_execution=True, #we run shadercode - could that be harmful? (all in the metric)
134134
)
135135
self.prompt = prompt # "minimal" or "full". "minimal" is the function header and comments before/after it, "full" is the whole code up until the function declaration ends
136+
# while we develop this dataset, we use a private dataset, so we override the init call here; otherwise loading silently fails (the warning never shows up)
137+
self.dataset = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME, trust_remote_code=True, use_auth_token=True)
136138

137139
def get_dataset(self):
138140
# TODO replace with subset once that is set up
@@ -160,6 +162,13 @@ def get_prompt(self, doc):
160162
# only have one alternative, but could be more?
161163
model_context += doc["model_ctx"]
162164
return model_context
165+
166+
def get_prompt_encoder(self, doc):
167+
"""
168+
this is needed for seq2seq models, but not available by default?
169+
"""
170+
enc_prompt = doc["model_ctx"] + "<extra_id_0>" #magic token to trigger generation for CodeT5p?
171+
return enc_prompt
163172

164173
def get_reference(self, doc):
165174
# TODO: get the reference solution from a sample `doc` from the dataset
@@ -213,6 +222,7 @@ def postprocess_generation(self, generation, idx):
213222
# from: https://huggingface.co/spaces/Vipitis/ShaderCoder/blob/main/utils/tree_utils.py#L45
214223
# generation = ShaderCoder.utils.parse_functions(generation)[0].text.decode() #not easily imported...
215224

225+
print(generation)
216226

217227
# assemble into the full code with just the function replaced
218228
ref = self.dataset["test"][idx]

bigcode_eval/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,11 @@ def complete_code(
281281
)
282282
else:
283283
generated_tokens = model.generate(
284-
decoder_input_ids=inputs,
284+
# decoder_input_ids=inputs,
285285
input_ids=batch["ids_encoder"][:, : batch["input_len_encoder"]],
286286
num_return_sequences=batch_size,
287-
decoder_start_token_id=tokenizer.pad_token_id,
288-
eos_token_id=tokenizer.eos_token_id,
287+
# decoder_start_token_id=tokenizer.pad_token_id,
288+
# eos_token_id=tokenizer.eos_token_id,
289289
**gen_kwargs,
290290
)
291291
else:

0 commit comments

Comments
 (0)