theapemachine committed on
Commit
0ac64e3
·
1 Parent(s): 9522390

Refactor device handling in the benchmark and cortex modules to use the `resolve_torch_device` function for improved device selection. Update the README with benchmark table formatting and additional device options.

Browse files
README.md CHANGED
@@ -132,22 +132,22 @@ Cortex includes a comprehensive benchmark harness for comparing base LLMs agains
132
 
133
  ### Standard Benchmarks
134
 
135
- | Task | Type | Choices | Dataset | Few-Shot |
136
- |------|------|---------|---------|----------|
137
- | **HellaSwag** | Commonsense NLI | 4 | `Rowan/hellaswag` | 5-shot |
138
- | **ARC-Easy** | Science QA | 3-5 | `allenai/ai2_arc` | 5-shot |
139
- | **ARC-Challenge** | Science QA (hard) | 3-5 | `allenai/ai2_arc` | 5-shot |
140
- | **PIQA** | Physical intuition | 2 | `gimmaru/piqa` | 0-shot |
141
- | **WinoGrande** | Coreference | 2 | `allenai/winogrande` | 5-shot |
142
- | **MMLU** | Multi-domain knowledge | 4 | `cais/mmlu` | 5-shot |
143
- | **HaluEval** | Hallucination detection | 2 | `pminervini/HaluEval` | 0-shot |
144
 
145
  ### Cortex-Specific Benchmarks
146
 
147
- | Task | Tests | Method |
148
- |------|-------|--------|
149
  | **Passkey Retrieval** | Long-context memory, attention to details | Generation + substring match at 128/256/512/1024 token contexts |
150
- | **Multi-Hop Memory** | Compositional reasoning, fact chaining | Generation + answer extraction from 3-hop fact chains |
151
 
152
  ### Running Benchmarks
153
 
@@ -172,6 +172,8 @@ python -m benchmark.run_benchmark --n 50 --no-memory
172
 
173
  # Custom passkey test
174
  python -m benchmark.run_benchmark --n 20 --passkey-lengths 128 256 512 1024 --n-passkey 10
 
 
175
  ```
176
 
177
  ### Scoring Method
@@ -242,13 +244,13 @@ All modules are independent and composable. Use any combination:
242
 
243
  ## Injection Points
244
 
245
- | Point | Location | Best For |
246
- |-------|----------|----------|
247
- | `PRE_ATTENTION` | Before self-attention | Input preprocessing, prefix injection |
248
- | `POST_ATTENTION` | After attention, before FFN | Memory augmentation (reads enhance attention output) |
249
- | `PRE_FFN` | Before FFN | Gate what the FFN processes |
250
- | `POST_FFN` | After full block | Gating, confidence estimation |
251
- | `RESIDUAL_STREAM` | Wraps entire block | Steering vectors, thinking tokens, backtracking |
252
 
253
  ## Layer Targeting
254
 
 
132
 
133
  ### Standard Benchmarks
134
 
135
+ | Task | Type | Choices | Dataset | Few-Shot |
136
+ |-------------------|-------------------------|---------|-----------------------|----------|
137
+ | **HellaSwag** | Commonsense NLI | 4 | `Rowan/hellaswag` | 5-shot |
138
+ | **ARC-Easy** | Science QA | 3-5 | `allenai/ai2_arc` | 5-shot |
139
+ | **ARC-Challenge** | Science QA (hard) | 3-5 | `allenai/ai2_arc` | 5-shot |
140
+ | **PIQA** | Physical intuition | 2 | `gimmaru/piqa` | 0-shot |
141
+ | **WinoGrande** | Coreference | 2 | `allenai/winogrande` | 5-shot |
142
+ | **MMLU** | Multi-domain knowledge | 4 | `cais/mmlu` | 5-shot |
143
+ | **HaluEval** | Hallucination detection | 2 | `pminervini/HaluEval` | 0-shot |
144
 
145
  ### Cortex-Specific Benchmarks
146
 
147
+ | Task | Tests | Method |
148
+ |-----------------------|-------------------------------------------|-----------------------------------------------------------------|
149
  | **Passkey Retrieval** | Long-context memory, attention to details | Generation + substring match at 128/256/512/1024 token contexts |
150
+ | **Multi-Hop Memory** | Compositional reasoning, fact chaining | Generation + answer extraction from 3-hop fact chains |
151
 
152
  ### Running Benchmarks
153
 
 
172
 
173
  # Custom passkey test
174
  python -m benchmark.run_benchmark --n 20 --passkey-lengths 128 256 512 1024 --n-passkey 10
175
+
176
+ python -m benchmark.run_benchmark --n 10 --model meta-llama/Llama-3.2-1B --tasks hellaswag piqa arc-easy arc-challenge winogrande mmlu
177
  ```
178
 
179
  ### Scoring Method
 
244
 
245
  ## Injection Points
246
 
247
+ | Point | Location | Best For |
248
+ |-------------------|-----------------------------|------------------------------------------------------|
249
+ | `PRE_ATTENTION` | Before self-attention | Input preprocessing, prefix injection |
250
+ | `POST_ATTENTION` | After attention, before FFN | Memory augmentation (reads enhance attention output) |
251
+ | `PRE_FFN` | Before FFN | Gate what the FFN processes |
252
+ | `POST_FFN` | After full block | Gating, confidence estimation |
253
+ | `RESIDUAL_STREAM` | Wraps entire block | Steering vectors, thinking tokens, backtracking |
254
 
255
  ## Layer Targeting
256
 
benchmark/memory_tasks.py CHANGED
@@ -16,6 +16,7 @@ import string
16
  from typing import List, Dict, Optional, Tuple
17
 
18
  from benchmark.scoring import generate_and_check
 
19
 
20
 
21
  class PasskeyRetrieval:
@@ -87,7 +88,7 @@ class PasskeyRetrieval:
87
  model,
88
  tokenizer,
89
  n_per_length: int = 5,
90
- device: str = "cuda",
91
  seed: int = 42,
92
  ) -> Dict:
93
  """
@@ -95,6 +96,8 @@ class PasskeyRetrieval:
95
 
96
  Returns dict with results per context length.
97
  """
 
 
98
  results = {}
99
 
100
  for ctx_len in self.context_lengths:
@@ -224,13 +227,15 @@ class MultiHopMemory:
224
  model,
225
  tokenizer,
226
  n: Optional[int] = None,
227
- device: str = "cuda",
228
  ) -> Dict:
229
  """
230
  Run multi-hop memory benchmark.
231
 
232
  Returns accuracy and per-example results.
233
  """
 
 
234
  chains = self.FACT_CHAINS
235
  if n is not None:
236
  chains = chains[:n]
 
16
  from typing import List, Dict, Optional, Tuple
17
 
18
  from benchmark.scoring import generate_and_check
19
+ from cortex.torch_device import resolve_torch_device
20
 
21
 
22
  class PasskeyRetrieval:
 
88
  model,
89
  tokenizer,
90
  n_per_length: int = 5,
91
+ device: Optional[str] = None,
92
  seed: int = 42,
93
  ) -> Dict:
94
  """
 
96
 
97
  Returns dict with results per context length.
98
  """
99
+ if device is None:
100
+ device = resolve_torch_device("auto")
101
  results = {}
102
 
103
  for ctx_len in self.context_lengths:
 
227
  model,
228
  tokenizer,
229
  n: Optional[int] = None,
230
+ device: Optional[str] = None,
231
  ) -> Dict:
232
  """
233
  Run multi-hop memory benchmark.
234
 
235
  Returns accuracy and per-example results.
236
  """
237
+ if device is None:
238
+ device = resolve_torch_device("auto")
239
  chains = self.FACT_CHAINS
240
  if n is not None:
241
  chains = chains[:n]
benchmark/run_benchmark.py CHANGED
@@ -57,7 +57,7 @@ def main():
57
  )
58
  parser.add_argument(
59
  "--device", type=str, default="auto",
60
- help="Device: cuda, cpu, or auto",
61
  )
62
  parser.add_argument(
63
  "--dtype", type=str, default="float32",
 
57
  )
58
  parser.add_argument(
59
  "--device", type=str, default="auto",
60
+ help="Device: cuda, mps, cpu, or auto (auto: cuda > mps > cpu)",
61
  )
62
  parser.add_argument(
63
  "--dtype", type=str, default="float32",
benchmark/runner.py CHANGED
@@ -21,6 +21,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
  from benchmark.scoring import log_likelihood_score, accuracy_from_loglikelihoods
22
  from benchmark.tasks import TASK_REGISTRY, BenchmarkTask
23
  from benchmark.memory_tasks import PasskeyRetrieval, MultiHopMemory
 
24
 
25
 
26
  class BenchmarkRunner:
@@ -43,7 +44,7 @@ class BenchmarkRunner:
43
  self.model_name = model_name
44
 
45
  if device == "auto":
46
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
47
  else:
48
  self.device = device
49
 
 
21
  from benchmark.scoring import log_likelihood_score, accuracy_from_loglikelihoods
22
  from benchmark.tasks import TASK_REGISTRY, BenchmarkTask
23
  from benchmark.memory_tasks import PasskeyRetrieval, MultiHopMemory
24
+ from cortex.torch_device import resolve_torch_device
25
 
26
 
27
  class BenchmarkRunner:
 
44
  self.model_name = model_name
45
 
46
  if device == "auto":
47
+ self.device = resolve_torch_device("auto")
48
  else:
49
  self.device = device
50
 
benchmark/scoring.py CHANGED
@@ -14,6 +14,8 @@ import torch.nn.functional as F
14
  from typing import List, Optional, Tuple, Dict
15
  import re
16
 
 
 
17
 
18
  @torch.no_grad()
19
  def log_likelihood_score(
@@ -21,7 +23,7 @@ def log_likelihood_score(
21
  tokenizer,
22
  context: str,
23
  continuations: List[str],
24
- device: str = "cuda",
25
  ) -> List[float]:
26
  """
27
  Compute normalized log-likelihood for each continuation given a context.
@@ -36,11 +38,13 @@ def log_likelihood_score(
36
  tokenizer: The tokenizer
37
  context: The prompt/context string
38
  continuations: List of possible continuations to score
39
- device: Device to use
40
 
41
  Returns:
42
  List of normalized log-likelihood scores (higher = model prefers this continuation)
43
  """
 
 
44
  scores = []
45
 
46
  for cont in continuations:
@@ -96,7 +100,7 @@ def generate_and_check(
96
  prompt: str,
97
  expected: str,
98
  max_new_tokens: int = 64,
99
- device: str = "cuda",
100
  exact_match: bool = False,
101
  ) -> Tuple[bool, str]:
102
  """
@@ -108,12 +112,14 @@ def generate_and_check(
108
  prompt: The input prompt
109
  expected: The expected answer string
110
  max_new_tokens: Max tokens to generate
111
- device: Device
112
  exact_match: If True, requires exact match; otherwise substring match
113
 
114
  Returns:
115
  (is_correct, generated_text)
116
  """
 
 
117
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
118
 
119
  # Pad token
 
14
  from typing import List, Optional, Tuple, Dict
15
  import re
16
 
17
+ from cortex.torch_device import resolve_torch_device
18
+
19
 
20
  @torch.no_grad()
21
  def log_likelihood_score(
 
23
  tokenizer,
24
  context: str,
25
  continuations: List[str],
26
+ device: Optional[str] = None,
27
  ) -> List[float]:
28
  """
29
  Compute normalized log-likelihood for each continuation given a context.
 
38
  tokenizer: The tokenizer
39
  context: The prompt/context string
40
  continuations: List of possible continuations to score
41
+ device: Device to use (default: auto — cuda, then mps, then cpu)
42
 
43
  Returns:
44
  List of normalized log-likelihood scores (higher = model prefers this continuation)
45
  """
46
+ if device is None:
47
+ device = resolve_torch_device("auto")
48
  scores = []
49
 
50
  for cont in continuations:
 
100
  prompt: str,
101
  expected: str,
102
  max_new_tokens: int = 64,
103
+ device: Optional[str] = None,
104
  exact_match: bool = False,
105
  ) -> Tuple[bool, str]:
106
  """
 
112
  prompt: The input prompt
113
  expected: The expected answer string
114
  max_new_tokens: Max tokens to generate
115
+ device: Device (default: auto — cuda, then mps, then cpu)
116
  exact_match: If True, requires exact match; otherwise substring match
117
 
118
  Returns:
119
  (is_correct, generated_text)
120
  """
121
+ if device is None:
122
+ device = resolve_torch_device("auto")
123
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
124
 
125
  # Pad token
cortex/steering_vector.py CHANGED
@@ -26,6 +26,7 @@ import torch.nn as nn
26
  import torch.nn.functional as F
27
  from typing import Optional, Union, List, Dict, Tuple
28
  from cortex.core import CortexModule, InjectionPoint
 
29
 
30
 
31
  class SteeringVector(CortexModule):
@@ -120,7 +121,7 @@ class SteeringVector(CortexModule):
120
  negative_prompts: List[str],
121
  tokenizer,
122
  layer_idx: int,
123
- device: str = "cuda"
124
  ) -> torch.Tensor:
125
  """
126
  Extract a steering direction via contrastive activation analysis.
@@ -137,11 +138,13 @@ class SteeringVector(CortexModule):
137
  negative_prompts: Prompts exemplifying the undesired behavior
138
  tokenizer: Model's tokenizer
139
  layer_idx: Which layer to extract from
140
- device: Device
141
 
142
  Returns:
143
  direction: [hidden_dim] steering direction vector
144
  """
 
 
145
  model.eval()
146
 
147
  def get_activations(prompts):
 
26
  import torch.nn.functional as F
27
  from typing import Optional, Union, List, Dict, Tuple
28
  from cortex.core import CortexModule, InjectionPoint
29
+ from cortex.torch_device import resolve_torch_device
30
 
31
 
32
  class SteeringVector(CortexModule):
 
121
  negative_prompts: List[str],
122
  tokenizer,
123
  layer_idx: int,
124
+ device: Optional[str] = None,
125
  ) -> torch.Tensor:
126
  """
127
  Extract a steering direction via contrastive activation analysis.
 
138
  negative_prompts: Prompts exemplifying the undesired behavior
139
  tokenizer: Model's tokenizer
140
  layer_idx: Which layer to extract from
141
+ device: Device (default: auto — cuda, then mps, then cpu)
142
 
143
  Returns:
144
  direction: [hidden_dim] steering direction vector
145
  """
146
+ if device is None:
147
+ device = resolve_torch_device("auto")
148
  model.eval()
149
 
150
  def get_activations(prompts):
cortex/torch_device.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch device selection (CUDA, Apple MPS, CPU)."""
2
+
3
from typing import Optional

import torch
4
+
5
+
6
def resolve_torch_device(preference: Optional[str] = "auto") -> str:
    """
    Resolve a device string for PyTorch.

    ``auto`` (matched case-insensitively, ignoring surrounding whitespace)
    or ``None`` prefers CUDA, then Apple Metal (MPS), then CPU.
    Any other string is returned with surrounding whitespace stripped and
    otherwise unchanged (e.g. ``cuda:0``); it is not validated here.

    Args:
        preference: Requested device string, ``"auto"``, or ``None``
            (treated the same as ``"auto"``).

    Returns:
        The explicit preference, or one of ``"cuda"``, ``"mps"``, ``"cpu"``
        when auto-detection is requested.
    """
    # Callers commonly hold ``device: Optional[str] = None``; accepting None
    # here avoids an AttributeError on ``None.strip()``.
    if preference is not None:
        requested = preference.strip()
        if requested.lower() != "auto":
            # Return the stripped form so " cuda " does not leak whitespace
            # into a device string; case is left untouched.
            return requested

    if torch.cuda.is_available():
        return "cuda"

    # ``torch.backends.mps`` may be absent on older PyTorch builds, so look
    # it up defensively before querying availability.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"
test_cortex.py CHANGED
@@ -16,6 +16,7 @@ import sys
16
  import os
17
 
18
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
19
  from cortex import (
20
  CortexSurgeon,
21
  MemoryBank,
@@ -29,7 +30,7 @@ import logging
29
  logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")
30
 
31
  def main():
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
  print(f"\n{'='*60}")
34
  print(f"CORTEX TEST — Device: {device}")
35
  print(f"{'='*60}\n")
 
16
  import os
17
 
18
  from transformers import AutoModelForCausalLM, AutoTokenizer
19
+ from cortex.torch_device import resolve_torch_device
20
  from cortex import (
21
  CortexSurgeon,
22
  MemoryBank,
 
30
  logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")
31
 
32
  def main():
33
+ device = resolve_torch_device("auto")
34
  print(f"\n{'='*60}")
35
  print(f"CORTEX TEST — Device: {device}")
36
  print(f"{'='*60}\n")