Skip to content

Commit d86b611

Browse files
committed
fix minor type issues; add type ignore to loosely typed files
1 parent 715d963 commit d86b611

File tree

8 files changed

+82
-47
lines changed

8 files changed

+82
-47
lines changed

delphi/__main__.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from delphi.pipeline import Pipe, Pipeline, process_wrapper
2727
from delphi.scorers import DetectionScorer, FuzzingScorer
2828
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
29-
from delphi.utils import load_tokenized_data
29+
from delphi.utils import assert_type, load_tokenized_data
3030

3131

3232
def load_artifacts(run_cfg: RunConfig):
@@ -325,8 +325,11 @@ async def run(
325325
hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
326326
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)
327327

328-
nrh = non_redundant_hookpoints(
329-
hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
328+
nrh = assert_type(
329+
dict,
330+
non_redundant_hookpoints(
331+
hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
332+
),
330333
)
331334
if nrh:
332335
populate_cache(
@@ -340,8 +343,11 @@ async def run(
340343

341344
del model, hookpoint_to_sparse_encode
342345
if run_cfg.constructor_cfg.non_activating_source == "neighbours":
343-
nrh = non_redundant_hookpoints(
344-
hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
346+
nrh = assert_type(
347+
list,
348+
non_redundant_hookpoints(
349+
hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
350+
),
345351
)
346352
if nrh:
347353
create_neighbours(
@@ -353,8 +359,11 @@ async def run(
353359
else:
354360
print("Skipping neighbour creation")
355361

356-
nrh = non_redundant_hookpoints(
357-
hookpoints, scores_path, "scores" in run_cfg.overwrite
362+
nrh = assert_type(
363+
list,
364+
non_redundant_hookpoints(
365+
hookpoints, scores_path, "scores" in run_cfg.overwrite
366+
),
358367
)
359368
if nrh:
360369
await process_cache(

delphi/latents/latents.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class LatentRecord:
134134
train: list[ActivatingExample] = field(default_factory=list)
135135
"""Training examples."""
136136

137-
test: list[ActivatingExample] = field(default_factory=list)
137+
test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
138138
"""Test examples."""
139139

140140
neighbours: list[Neighbour] = field(default_factory=list)
@@ -143,6 +143,9 @@ class LatentRecord:
143143
explanation: str = ""
144144
"""Explanation of the latent."""
145145

146+
extra_examples: Optional[list[Example]] = None
147+
"""Extra examples to include in the record."""
148+
146149
@property
147150
def max_activation(self) -> float:
148151
"""

delphi/scorers/simulator/oai_autointerp/explanations/simulator.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def parse_top_logprobs(top_logprobs: dict[str, float]) -> OrderedDict[int, float
107107
"""
108108
probabilities_by_distribution_value = OrderedDict()
109109
for token, contents in top_logprobs.items():
110-
logprob = contents.logprob
111-
decoded_token = contents.decoded_token
110+
logprob = contents.logprob # type: ignore
111+
decoded_token = contents.decoded_token # type: ignore
112112
if decoded_token in VALID_ACTIVATION_TOKENS:
113113
token_as_int = int(decoded_token)
114114
probabilities_by_distribution_value[token_as_int] = np.exp(logprob)
@@ -134,7 +134,7 @@ def compute_predicted_activation_stats_for_token(
134134

135135

136136
def parse_simulation_response(
137-
response: dict[str, Any],
137+
response: Any,
138138
tokenized_prompt: list[int],
139139
tab_token: int,
140140
tokens: Sequence[str],
@@ -250,11 +250,11 @@ async def simulate(
250250
else:
251251
assert isinstance(prompt, str)
252252

253-
response = await self.client.generate(prompt, **sampling_params)
254-
tokenized_prompt = self.client.tokenizer.apply_chat_template(
253+
response = await self.client.generate(prompt, **sampling_params) # type: ignore
254+
tokenized_prompt = self.client.tokenizer.apply_chat_template( # type: ignore
255255
prompt, add_generation_prompt=True
256256
)
257-
tab_token = self.client.tokenizer.encode("\t")[1]
257+
tab_token = self.client.tokenizer.encode("\t")[1] # type: ignore
258258
logger.debug("response in score_explanation_by_activations is %s", response)
259259
try:
260260
result = parse_simulation_response(
@@ -287,7 +287,7 @@ def make_simulation_prompt(
287287
# Consider reconciling them.
288288
prompt_builder = PromptBuilder()
289289
prompt_builder.add_message(
290-
"system",
290+
"system", # type: ignore
291291
"""We're studying neurons in a neural network.
292292
Each neuron looks for some particular thing in a short document.
293293
Look at summary of what the neuron does, and try to predict how it will fire on each token.
@@ -299,7 +299,7 @@ def make_simulation_prompt(
299299
few_shot_examples = self.few_shot_example_set.get_examples()
300300
for i, example in enumerate(few_shot_examples):
301301
prompt_builder.add_message(
302-
"user",
302+
"user", # type: ignore
303303
f"\n\nNeuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX}"
304304
f"{example.explanation}",
305305
)
@@ -309,17 +309,17 @@ def make_simulation_prompt(
309309
start_indices=example.first_revealed_activation_indices,
310310
)
311311
prompt_builder.add_message(
312-
"assistant", f"\nActivations: {formatted_activation_records}\n"
312+
"assistant", f"\nActivations: {formatted_activation_records}\n" # type: ignore
313313
)
314314

315315
prompt_builder.add_message(
316-
"user",
316+
"user", # type: ignore
317317
f"\n\nNeuron {len(few_shot_examples) + 1}\nExplanation of neuron "
318318
f"{len(few_shot_examples) + 1} behavior: {EXPLANATION_PREFIX} "
319319
f"{self.explanation.strip()}",
320320
)
321321
prompt_builder.add_message(
322-
"assistant",
322+
"assistant", # type: ignore
323323
f"\nActivations: {format_sequences_for_simulation([tokens])}",
324324
)
325325
return prompt_builder.build(self.prompt_format)
@@ -595,7 +595,7 @@ async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:
595595

596596
result = SequenceSimulation(
597597
activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
598-
expected_activations=predicted_activations,
598+
expected_activations=predicted_activations, # type: ignore
599599
# Since the predicted activation is just a sampled token, we don't have a distribution.
600600
distribution_values=[],
601601
distribution_probabilities=[],
@@ -614,7 +614,7 @@ def _make_simulation_prompt_json(
614614
assert explanation != ""
615615
prompt_builder = PromptBuilder()
616616
prompt_builder.add_message(
617-
"system",
617+
"system", # type: ignore
618618
"""We're studying neurons in a neural network. Each neuron looks for certain things in a short document. Your task is to read the explanation of what the neuron does, and predict the neuron's activations for each token in the document.
619619
620620
For each document, you will see the full text of the document, then the tokens in the document with the activation left blank. You will print, in valid json, the exact same tokens verbatim, but with the activation values filled in according to the explanation. Pay special attention to the explanation's description of the context and order of tokens or words.
@@ -638,7 +638,7 @@ def _make_simulation_prompt_json(
638638
}
639639
"""
640640
prompt_builder.add_message(
641-
"user",
641+
"user", # type: ignore
642642
_format_record_for_logprob_free_simulation_json(
643643
explanation=example.explanation,
644644
activation_record=example.activation_records[0],
@@ -658,7 +658,7 @@ def _make_simulation_prompt_json(
658658
}
659659
"""
660660
prompt_builder.add_message(
661-
"assistant",
661+
"assistant", # type: ignore
662662
_format_record_for_logprob_free_simulation_json(
663663
explanation=example.explanation,
664664
activation_record=example.activation_records[0],
@@ -678,10 +678,10 @@ def _make_simulation_prompt_json(
678678
}
679679
"""
680680
prompt_builder.add_message(
681-
"user",
681+
"user", # type: ignore
682682
_format_record_for_logprob_free_simulation_json(
683683
explanation=explanation,
684-
activation_record=ActivationRecord(tokens=tokens, activations=[]),
684+
activation_record=ActivationRecord(tokens=tokens, activations=[]), # type: ignore
685685
include_activations=False,
686686
),
687687
)
@@ -698,7 +698,7 @@ def _make_simulation_prompt(
698698
assert explanation != ""
699699
prompt_builder = PromptBuilder()
700700
prompt_builder.add_message(
701-
"system",
701+
"system", # type: ignore
702702
"""We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.
703703
704704
The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.
@@ -716,7 +716,7 @@ def _make_simulation_prompt(
716716
example.activation_records[0], include_activations=False
717717
)
718718
prompt_builder.add_message(
719-
"user",
719+
"user", # type: ignore
720720
f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
721721
f"{example.explanation}\n\n"
722722
f"Sequence 1 Tokens without Activations:\n{tokens_without_activations}\n\n"
@@ -728,7 +728,7 @@ def _make_simulation_prompt(
728728
max_activation=few_shot_example_max_activation,
729729
)
730730
prompt_builder.add_message(
731-
"assistant",
731+
"assistant", # type: ignore
732732
f"{tokens_with_activations}\n\n",
733733
)
734734

@@ -737,7 +737,7 @@ def _make_simulation_prompt(
737737
record, include_activations=False
738738
)
739739
prompt_builder.add_message(
740-
"user",
740+
"user", # type: ignore
741741
f"Sequence {record_index + 2} Tokens without Activations:\n{tks_without}\n\n"
742742
f"Sequence {record_index + 2} Tokens with Activations:\n",
743743
)
@@ -747,16 +747,16 @@ def _make_simulation_prompt(
747747
max_activation=few_shot_example_max_activation,
748748
)
749749
prompt_builder.add_message(
750-
"assistant",
750+
"assistant", # type: ignore
751751
f"{tokens_with_activations}\n\n",
752752
)
753753

754754
neuron_index = len(few_shot_examples) + 1
755755
tokens_without_activations = _format_record_for_logprob_free_simulation(
756-
ActivationRecord(tokens=tokens, activations=[]), include_activations=False
756+
ActivationRecord(tokens=tokens, activations=[]), include_activations=False # type: ignore
757757
)
758758
prompt_builder.add_message(
759-
"user",
759+
"user", # type: ignore
760760
f"Neuron {neuron_index}\nExplanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} "
761761
f"{explanation}\n\n"
762762
f"Sequence 1 Tokens without Activations:\n{tokens_without_activations}\n\n"

delphi/scorers/surprisal/surprisal.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
from typing import NamedTuple
44

55
import torch
6+
from simple_parsing import field
67
from torch.nn.functional import cross_entropy
78
from transformers import PreTrainedTokenizer
89

9-
from ...latents import Example, LatentRecord
10+
from delphi.utils import assert_type
11+
12+
from ...latents import ActivatingExample, Example, LatentRecord
1013
from ..scorer import Scorer, ScorerResult
1114
from .prompts import BASEPROMPT as base_prompt
1215

@@ -19,13 +22,13 @@ class SurprisalOutput:
1922
distance: float | int
2023
"""Quantile or neighbor distance"""
2124

22-
no_explanation: list[float] = 0
25+
no_explanation: list[float] = field(default_factory=list)
2326
"""What is the surprisal of the model with no explanation"""
2427

25-
explanation: list[float] = 0
28+
explanation: list[float] = field(default_factory=list)
2629
"""What is the surprisal of the model with an explanation"""
2730

28-
activations: list[float] = 0
31+
activations: list[float] = field(default_factory=list)
2932
"""What are the activations of the model"""
3033

3134

@@ -55,7 +58,7 @@ def __init__(
5558
async def __call__(
5659
self,
5760
record: LatentRecord,
58-
) -> list[SurprisalOutput]:
61+
) -> ScorerResult:
5962
samples = self._prepare(record)
6063

6164
random.shuffle(samples)
@@ -66,21 +69,24 @@ async def __call__(
6669

6770
return ScorerResult(record=record, score=results)
6871

69-
def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
72+
def _prepare(self, record: LatentRecord) -> list[Sample]:
7073
"""
7174
Prepare and shuffle a list of samples for classification.
7275
"""
7376

7477
defaults = {
7578
"tokenizer": self.tokenizer,
7679
}
80+
81+
assert record.extra_examples is not None, "No extra examples provided"
7782
samples = examples_to_samples(
7883
record.extra_examples,
7984
distance=-1,
8085
**defaults,
8186
)
8287

8388
for i, examples in enumerate(record.test):
89+
examples = assert_type(list, examples)
8490
samples.extend(
8591
examples_to_samples(
8692
examples,
@@ -181,7 +187,7 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[SurprisalOutpu
181187

182188

183189
def examples_to_samples(
184-
examples: list[Example],
190+
examples: list[Example] | list[ActivatingExample],
185191
tokenizer: PreTrainedTokenizer,
186192
**sample_kwargs,
187193
) -> list[Sample]:

delphi/sparse_coders/load_sparsify.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33
from typing import Callable
44

55
import torch
6-
from sparsify import Sae
6+
from sparsify import SparseCoder
77
from torch import Tensor
88
from transformers import PreTrainedModel
99

1010

11-
def sae_dense_latents(x: Tensor, sae: Sae) -> Tensor:
11+
def sae_dense_latents(x: Tensor, sae: SparseCoder) -> Tensor:
1212
"""Run `sae` on `x`, yielding the dense activations."""
1313
pre_acts = sae.pre_acts(x)
1414
acts, indices = sae.select_topk(pre_acts)
1515
return torch.zeros_like(pre_acts).scatter_(-1, indices, acts)
1616

1717

18-
def resolve_path(model: PreTrainedModel, path_segments: list[str]) -> list[str] | None:
18+
def resolve_path(
19+
model: PreTrainedModel | torch.nn.Module, path_segments: list[str]
20+
) -> list[str] | None:
1921
"""Attempt to resolve the path segments to the model in the case where it
2022
has been wrapped (e.g. by a LanguageModel, causal model, or classifier)."""
2123
# If the first segment is a valid attribute, return the path segments
@@ -45,7 +47,7 @@ def load_sparsify_sparse_coders(
4547
hookpoints: list[str],
4648
device: str | torch.device,
4749
compile: bool = False,
48-
) -> dict[str, Sae]:
50+
) -> dict[str, SparseCoder]:
4951
"""
5052
Load sparsify sparse coders for specified hookpoints.
5153
@@ -67,7 +69,7 @@ def load_sparsify_sparse_coders(
6769
name_path = Path(name)
6870
if name_path.exists():
6971
for hookpoint in hookpoints:
70-
sparse_model_dict[hookpoint] = Sae.load_from_disk(
72+
sparse_model_dict[hookpoint] = SparseCoder.load_from_disk(
7173
name_path / hookpoint, device=device
7274
)
7375
if compile:
@@ -76,7 +78,7 @@ def load_sparsify_sparse_coders(
7678
)
7779
else:
7880
# Load on CPU first to not run out of memory
79-
sparse_models = Sae.load_many(name, device="cpu")
81+
sparse_models = SparseCoder.load_many(name, device="cpu")
8082
for hookpoint in hookpoints:
8183
sparse_model_dict[hookpoint] = sparse_models[hookpoint].to(device)
8284
if compile:

delphi/sparse_coders/sparse_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5+
from sparsify import SparseCoder
56
from transformers import PreTrainedModel
67

78
from delphi.config import RunConfig
@@ -74,7 +75,7 @@ def load_sparse_coders(
7475
run_cfg: RunConfig,
7576
device: str | torch.device,
7677
compile: bool = False,
77-
) -> dict[str, nn.Module]:
78+
) -> dict[str, nn.Module] | dict[str, SparseCoder]:
7879
"""
7980
Load sparse coders for specified hookpoints.
8081

0 commit comments

Comments
 (0)