Update README with model loading code
Browse files
README.md
CHANGED
|
@@ -14,9 +14,29 @@ Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapv
|
|
| 14 |
|
| 15 |
## Description
|
| 16 |
|
| 17 |
-
Task heads for **Differential VQA
|
| 18 |
-
answer
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
## Results (test set)
|
| 22 |
|
|
@@ -26,14 +46,30 @@ Trained on MIMIC-Diff-VQA with five **frozen** off-the-shelf vision encoders.
|
|
| 26 |
| CoCa | 0.196 | 0.138 | 0.320 | 0.317 |
|
| 27 |
| Florence-2 | 0.191 | 0.138 | 0.319 | 0.318 |
|
| 28 |
| SigLIP | 0.186 | 0.131 | 0.322 | 0.313 |
|
| 29 |
-
| OWLv2 | — | — | — | — |
|
| 30 |
|
| 31 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
## Description
|
| 16 |
|
| 17 |
+
Task heads for **Differential VQA**: given a *prior* and a *current* chest X-ray,
|
| 18 |
+
answer questions about radiological changes. Trained on MIMIC-Diff-VQA with five
|
| 19 |
+
**frozen** encoders. Each `.pt` file is a plain state dict of `DiffVQAHead`.
|
| 20 |
+
|
| 21 |
+
## Architecture — `DiffVQAHead`
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
vis_proj : Linear(vis_dim → 512) # shared for both images
|
| 25 |
+
frame_emb : Embedding(2, 512) # 0=reference, 1=current
|
| 26 |
+
memory : [ref_proj + frame_emb(0) ; curr_proj + frame_emb(1)] → [B, 2N, 512]
|
| 27 |
+
tok_emb : Embedding(50257, 512)
|
| 28 |
+
pos_emb : Embedding(200, 512)
|
| 29 |
+
decoder : 6 × TransformerDecoderLayer (pre-norm)
|
| 30 |
+
lm_head : Linear(512 → 50257, bias=False)
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
| File | Encoder | vis_dim |
|
| 34 |
+
|---|---|---|
|
| 35 |
+
| `clip-vit-l14_best.pt` | CLIP ViT-L/14 | 1024 |
|
| 36 |
+
| `coca_best.pt` | CoCa | 768 |
|
| 37 |
+
| `florence2_best.pt` | Florence-2 | 1024 |
|
| 38 |
+
| `siglip_best.pt` | SigLIP | 1152 |
|
| 39 |
+
| `owlv2_best.pt` | OWLv2 | 1024 |
|
| 40 |
|
| 41 |
## Results (test set)
|
| 42 |
|
|
|
|
| 46 |
| CoCa | 0.196 | 0.138 | 0.320 | 0.317 |
|
| 47 |
| Florence-2 | 0.191 | 0.138 | 0.319 | 0.318 |
|
| 48 |
| SigLIP | 0.186 | 0.131 | 0.322 | 0.313 |
|
|
|
|
| 49 |
|
| 50 |
+
## Loading
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
import torch
|
| 54 |
+
import tiktoken
|
| 55 |
+
from lapvqa.diffvqa.model import DiffVQAHead
|
| 56 |
+
|
| 57 |
+
ckpt = torch.load("coca_best.pt", map_location="cpu")
|
| 58 |
+
head = DiffVQAHead(vis_dim=768) # adjust vis_dim per encoder
|
| 59 |
+
head.load_state_dict(ckpt)
|
| 60 |
+
head.eval()
|
| 61 |
+
|
| 62 |
+
enc = tiktoken.get_encoding("gpt2")
|
| 63 |
+
bos_id = eos_id = enc.eot_token
|
| 64 |
|
| 65 |
+
# curr_vis, ref_vis: [B, N, vis_dim] — patch tokens from the frozen encoder
|
| 66 |
+
answers = head.generate(
|
| 67 |
+
curr_vis = curr_vis,
|
| 68 |
+
ref_vis = ref_vis,
|
| 69 |
+
prompt_ids = question_ids, # [B, Q]
|
| 70 |
+
bos_id = bos_id,
|
| 71 |
+
eos_id = eos_id,
|
| 72 |
+
max_new_tokens = 128,
|
| 73 |
+
)
|
| 74 |
+
decoded = [enc.decode(ids) for ids in answers]
|
| 75 |
+
```
|