dmusingu committed on
Commit bcb405b · verified · 1 Parent(s): 0dcb401

Update README with model loading code

Files changed (1)
  1. README.md +48 -12
README.md CHANGED
@@ -14,9 +14,29 @@ Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapv
 
  ## Description
 
- Task heads for **Differential VQA (DiffVQA)**: given a *prior* and a *current* chest X-ray,
- answer natural-language questions about radiological changes between the two studies.
- Trained on MIMIC-Diff-VQA with five **frozen** off-the-shelf vision encoders.
 
  ## Results (test set)
 
@@ -26,14 +46,30 @@ Trained on MIMIC-Diff-VQA with five **frozen** off-the-shelf vision encoders.
  | CoCa | 0.196 | 0.138 | 0.320 | 0.317 |
  | Florence-2 | 0.191 | 0.138 | 0.319 | 0.318 |
  | SigLIP | 0.186 | 0.131 | 0.322 | 0.313 |
- | OWLv2 | — | — | — | — |
 
- ## Files
 
- | File | Encoder backbone |
- |---|---|
- | `clip-vit-l14_best.pt` | CLIP ViT-L/14 |
- | `coca_best.pt` | CoCa |
- | `florence2_best.pt` | Florence-2 |
- | `siglip_best.pt` | SigLIP |
- | `owlv2_best.pt` | OWLv2 |
 
  ## Description
 
+ Task heads for **Differential VQA**: given a *prior* and a *current* chest X-ray,
+ answer questions about radiological changes. Trained on MIMIC-Diff-VQA with five
+ **frozen** encoders. Each `.pt` file is a plain state dict of `DiffVQAHead`.
+
+ ## Architecture — `DiffVQAHead`
+
+ ```
+ vis_proj  : Linear(vis_dim → 512)      # shared for both images
+ frame_emb : Embedding(2, 512)          # 0=reference, 1=current
+ memory    : [ref_proj + frame_emb(0) ; curr_proj + frame_emb(1)] → [B, 2N, 512]
+ tok_emb   : Embedding(50257, 512)
+ pos_emb   : Embedding(200, 512)
+ decoder   : 6 × TransformerDecoderLayer (pre-norm)
+ lm_head   : Linear(512 → 50257, bias=False)
+ ```
+
+ | File | Encoder | vis_dim |
+ |---|---|---|
+ | `clip-vit-l14_best.pt` | CLIP ViT-L/14 | 1024 |
+ | `coca_best.pt` | CoCa | 768 |
+ | `florence2_best.pt` | Florence-2 | 1024 |
+ | `siglip_best.pt` | SigLIP | 1152 |
+ | `owlv2_best.pt` | OWLv2 | 1024 |
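
As a rough illustration of how the file table and the architecture summary fit together, the sketch below looks up `vis_dim` by checkpoint filename and counts the parameters of the layers whose shapes are fully stated. The names `VIS_DIM` and `specified_param_count` are hypothetical and not part of the `lapvqa` package. The count assumes a bias on `vis_proj` (PyTorch's `Linear` default) and no weight tying between `tok_emb` and `lm_head`; the decoder stack is left out because its head count and feed-forward width are not given.

```python
# Hypothetical helper for illustration; not part of the lapvqa package.
D_MODEL = 512    # hidden size from the architecture summary
VOCAB = 50257    # vocabulary size used by tok_emb and lm_head
MAX_POS = 200    # length of the pos_emb table

# vis_dim per checkpoint, copied from the file table.
VIS_DIM = {
    "clip-vit-l14_best.pt": 1024,
    "coca_best.pt": 768,
    "florence2_best.pt": 1024,
    "siglip_best.pt": 1152,
    "owlv2_best.pt": 1024,
}


def specified_param_count(vis_dim: int) -> int:
    """Count parameters of every listed layer except the decoder stack."""
    vis_proj = vis_dim * D_MODEL + D_MODEL  # weight + bias (bias is an assumption)
    frame_emb = 2 * D_MODEL
    tok_emb = VOCAB * D_MODEL
    pos_emb = MAX_POS * D_MODEL
    lm_head = D_MODEL * VOCAB               # bias=False per the summary
    return vis_proj + frame_emb + tok_emb + pos_emb + lm_head


if __name__ == "__main__":
    for name, dim in VIS_DIM.items():
        print(f"{name}: vis_dim={dim}, non-decoder params={specified_param_count(dim):,}")
```

Only `vis_proj` depends on the encoder, so checkpoints for different backbones should differ only in that layer's weight shape; passing the wrong `vis_dim` would typically surface as a size mismatch in `load_state_dict`.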
 
  ## Results (test set)
 

  | CoCa | 0.196 | 0.138 | 0.320 | 0.317 |
  | Florence-2 | 0.191 | 0.138 | 0.319 | 0.318 |
  | SigLIP | 0.186 | 0.131 | 0.322 | 0.313 |
 
+ ## Loading
+
+ ```python
+ import torch
+ import tiktoken
+ from lapvqa.diffvqa.model import DiffVQAHead
+
+ ckpt = torch.load("coca_best.pt", map_location="cpu")
+ head = DiffVQAHead(vis_dim=768)  # adjust vis_dim per encoder
+ head.load_state_dict(ckpt)
+ head.eval()
+
+ enc = tiktoken.get_encoding("gpt2")
+ bos_id = eos_id = enc.eot_token
 
+ # curr_vis, ref_vis: [B, N, vis_dim] — patch tokens from the frozen encoder
+ answers = head.generate(
+     curr_vis       = curr_vis,
+     ref_vis        = ref_vis,
+     prompt_ids     = question_ids,  # [B, Q]
+     bos_id         = bos_id,
+     eos_id         = eos_id,
+     max_new_tokens = 128,
+ )
+ decoded = [enc.decode(ids) for ids in answers]
+ ```
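
The loading snippet above decodes whatever `generate` returns, and the README does not say whether that output still contains the BOS/EOS token or anything after it. The sketch below is one possible post-processing step under the assumption that each answer is a flat sequence of token ids (call `.tolist()` first if it is a tensor). `trim_at_eos` is a hypothetical helper; `50256` is the value of `enc.eot_token` for tiktoken's `gpt2` encoding. Because `bos_id` and `eos_id` share one id, leading copies are skipped before cutting at the next occurrence.

```python
from typing import List, Sequence

GPT2_EOT = 50256  # id of <|endoftext|> in the gpt2 encoding


def trim_at_eos(ids: Sequence[int], eos_id: int = GPT2_EOT) -> List[int]:
    """Drop leading BOS/EOS ids, then keep tokens up to the next EOS."""
    start = 0
    # bos_id == eos_id in the README snippet, so skip any leading copies.
    while start < len(ids) and ids[start] == eos_id:
        start += 1
    trimmed: List[int] = []
    for tok in ids[start:]:
        if tok == eos_id:
            break  # everything after the first real EOS is discarded
        trimmed.append(tok)
    return trimmed
```

With this helper, the last line of the snippet could become `decoded = [enc.decode(trim_at_eos(ids)) for ids in answers]`. If `generate` already strips these tokens, the helper returns its input unchanged.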