From c26a66e96c6c81986b0b772d03bf151ad9975ee7 Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Wed, 4 Jun 2025 16:51:07 -0700
Subject: [PATCH] fixes EmbeddingScorer._prepare() passes arg of wrong type

---
 delphi/scorers/embedding/embedding.py | 33 ++++++++++++++-------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/delphi/scorers/embedding/embedding.py b/delphi/scorers/embedding/embedding.py
index 2de89874..26943623 100644
--- a/delphi/scorers/embedding/embedding.py
+++ b/delphi/scorers/embedding/embedding.py
@@ -51,7 +51,7 @@ async def __call__(  # type: ignore
         random.shuffle(samples)
         results = self._query(
             record.explanation,
-            samples,  # type: ignore
+            samples,
         )
 
         return ScorerResult(record=record, score=results)
@@ -59,30 +59,31 @@
     def call_sync(self, record: LatentRecord) -> list[EmbeddingOutput]:
         return asyncio.run(self.__call__(record))  # type: ignore
 
-    def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:
         """
         Prepare and shuffle a list of samples for classification.
         """
+        samples = []
 
-        defaults = {
-            "tokenizer": self.tokenizer,
-        }
-        samples = examples_to_samples(
-            record.extra_examples,  # type: ignore
-            distance=-1,
-            **defaults,  # type: ignore
-        )
+        if record.extra_examples is not None:
+            samples.extend(
+                examples_to_samples(
+                    record.extra_examples,
+                    tokenizer=self.tokenizer,
+                    distance=-1,
+                )
+            )
 
-        for i, examples in enumerate(record.test):
+        for i, example in enumerate(record.test):
             samples.extend(
                 examples_to_samples(
-                    examples,  # type: ignore
+                    [example],
+                    tokenizer=self.tokenizer,
                     distance=i + 1,
-                    **defaults,  # type: ignore
                 )
             )
 
-        return samples  # type: ignore
+        return samples
 
     def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutput]:
         explanation_string = (
@@ -110,7 +111,7 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutpu
 
 def examples_to_samples(
     examples: list[Example],
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: PreTrainedTokenizer | None,
     **sample_kwargs,
 ) -> list[Sample]:
     samples = []
@@ -118,7 +119,7 @@ def examples_to_samples(
         if tokenizer is not None:
             text = "".join(tokenizer.batch_decode(example.tokens))
         else:
-            text = "".join(example.tokens)
+            text = "".join(str(token) for token in example.tokens)
         activations = example.activations.tolist()
         samples.append(
             Sample(
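
A minimal sketch (not part of the patch) of the failure the last hunk fixes:
when no tokenizer is configured, example.tokens holds raw token ids rather
than strings, and str.join() raises TypeError on non-string items. The token
values below are made up for illustration.

    # tokens as raw ids, as when tokenizer is None
    tokens = [101, 2054, 102]

    try:
        text = "".join(tokens)  # old code: TypeError, join expects str items
    except TypeError as err:
        print(f"old behavior: {err}")

    # patched behavior: stringify each token id before joining
    text = "".join(str(token) for token in tokens)
    print(text)  # -> "1012054102"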