---
tags:
- chest-xray
- radiology
- visual-question-answering
- differential-vqa
- mimic-cxr
license: apache-2.0
---
# LAPVQA: Differential VQA (Frozen Off-the-shelf Encoders)
Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapvqa).
## Description
Task heads for **Differential VQA**: given a *prior* (reference) and a *current* chest X-ray,
answer questions about the radiological changes between them. Trained on MIMIC-Diff-VQA with five
**frozen** encoders. Each `.pt` file is a plain state dict of `DiffVQAHead`.
## Architecture: `DiffVQAHead`
```
vis_proj  : Linear(vis_dim → 512)            # shared for both images
frame_emb : Embedding(2, 512)                # 0=reference, 1=current
memory    : [ref_proj + frame_emb(0) ; curr_proj + frame_emb(1)] → [B, 2N, 512]
tok_emb   : Embedding(50257, 512)
pos_emb   : Embedding(200, 512)
decoder   : 6 × TransformerDecoderLayer (pre-norm)
lm_head   : Linear(512 → 50257, bias=False)
```
| File | Encoder | vis_dim |
|---|---|---|
| `clip-vit-l14_best.pt` | CLIP ViT-L/14 | 1024 |
| `coca_best.pt` | CoCa | 768 |
| `florence2_best.pt` | Florence-2 | 1024 |
| `siglip_best.pt` | SigLIP | 1152 |
| `owlv2_best.pt` | OWLv2 | 1024 |
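
The `vis_proj`, `frame_emb`, and `memory` lines of the architecture listing can be sketched as follows. This is a minimal illustration, not the repository's implementation: the class name `MemoryBuilder` is hypothetical, and it assumes the frame embedding is added by broadcasting over all patch tokens of an image and that `vis_proj` uses the default `nn.Linear` bias.

```python
import torch
import torch.nn as nn


class MemoryBuilder(nn.Module):
    """Hypothetical helper: builds the decoder memory from two sets of patch tokens."""

    def __init__(self, vis_dim: int, d_model: int = 512):
        super().__init__()
        self.vis_proj = nn.Linear(vis_dim, d_model)  # shared for both images
        self.frame_emb = nn.Embedding(2, d_model)    # 0 = reference, 1 = current

    def forward(self, ref_vis: torch.Tensor, curr_vis: torch.Tensor) -> torch.Tensor:
        # ref_vis, curr_vis: [B, N, vis_dim]
        ref = self.vis_proj(ref_vis) + self.frame_emb.weight[0]    # [B, N, d_model]
        curr = self.vis_proj(curr_vis) + self.frame_emb.weight[1]  # [B, N, d_model]
        # Concatenate along the token axis: reference tokens first, then current.
        return torch.cat([ref, curr], dim=1)                       # [B, 2N, d_model]


builder = MemoryBuilder(vis_dim=768)  # 768 matches the CoCa row of the table
memory = builder(torch.randn(2, 4, 768), torch.randn(2, 4, 768))
print(memory.shape)  # torch.Size([2, 8, 512])
```

Because the projection is shared, the frame embedding is the only signal that tells the decoder which image a given memory token came from.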
## Results (test set)
| Encoder | BLEU-1 | BLEU-4 | ROUGE-1 | RadGraph-s |
|---|---|---|---|---|
| CLIP ViT-L/14 | 0.184 | 0.128 | 0.336 | 0.322 |
| CoCa | 0.196 | 0.138 | 0.320 | 0.317 |
| Florence-2 | 0.191 | 0.138 | 0.319 | 0.318 |
| SigLIP | 0.186 | 0.131 | 0.322 | 0.313 |
## Loading
```python
import torch
import tiktoken
from lapvqa.diffvqa.model import DiffVQAHead

# Load the task head; each checkpoint is a plain state dict.
ckpt = torch.load("coca_best.pt", map_location="cpu")
head = DiffVQAHead(vis_dim=768)  # adjust vis_dim per encoder (see the file table)
head.load_state_dict(ckpt)
head.eval()

# GPT-2 BPE tokenizer; its end-of-text token serves as both BOS and EOS.
enc = tiktoken.get_encoding("gpt2")
bos_id = eos_id = enc.eot_token

# curr_vis, ref_vis: [B, N, vis_dim], patch tokens from the frozen encoder
# question_ids: [B, Q], token ids of the questions
answers = head.generate(
    curr_vis=curr_vis,
    ref_vis=ref_vis,
    prompt_ids=question_ids,  # [B, Q]
    bos_id=bos_id,
    eos_id=eos_id,
    max_new_tokens=128,
)
decoded = [enc.decode(ids) for ids in answers]
```
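
Since `vis_dim` must match the checkpoint, it may be convenient to read it from the state dict rather than hard-coding it. The sketch below assumes the projection weight is stored under the key `vis_proj.weight` (inferred from the architecture listing, not confirmed against the checkpoints) and relies on `nn.Linear` storing its weight as `[out_features, in_features]`. The helper `infer_vis_dim` is hypothetical, and the example uses a stand-in object so it runs without a `.pt` file.

```python
from types import SimpleNamespace


def infer_vis_dim(state_dict, key="vis_proj.weight"):
    """Return the encoder feature width a checkpoint expects.

    Assumes `state_dict[key]` has shape [512, vis_dim]; the key name is an
    assumption based on the architecture listing.
    """
    if key not in state_dict:
        sample = sorted(state_dict)[:5]
        raise KeyError(f"{key!r} not found; first keys in checkpoint: {sample}")
    return state_dict[key].shape[1]


# Stand-in for a real checkpoint (a real one would come from torch.load).
fake_ckpt = {"vis_proj.weight": SimpleNamespace(shape=(512, 768))}
print(infer_vis_dim(fake_ckpt))  # 768
```

With a real checkpoint this would be used as `DiffVQAHead(vis_dim=infer_vis_dim(ckpt))`. If the key is named differently, the `KeyError` message lists a few of the keys that are present.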