File size: 3,158 Bytes
16e8aa3
 
 
 
 
 
 
 
 
 
462cc5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16e8aa3
 
 
 
dae598b
 
08aa07f
dae598b
16e8aa3
 
08aa07f
 
16e8aa3
 
08aa07f
16e8aa3
 
7773eb5
08aa07f
7773eb5
 
 
462cc5a
 
7773eb5
462cc5a
 
 
 
 
 
 
 
7773eb5
 
 
904cdd6
 
 
 
 
7773eb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText

# Authenticate with the Hugging Face Hub using the HF_TOKEN environment
# variable (required: MedGemma is a gated model). Raises KeyError if unset.
login(token=os.environ["HF_TOKEN"])

# Hugging Face model repository id for the MedGemma image-text-to-text model.
# NOTE(review): verify this repo id exists on the Hub — the published 4B
# instruction-tuned checkpoint is commonly "google/medgemma-4b-it"; confirm
# "medgemma-1.5-4b-it" is the intended variant.
MODEL_ID = "google/medgemma-1.5-4b-it"

# System-style instruction prepended (as the final text item) to every
# request in generate_report(). The exact wording is part of runtime
# behavior — do not edit casually; it pins the report structure the model
# must emit (TECHNIQUE / FINDINGS / CONCLUSION).
PROMPT = """You are a senior consultant radiologist reporting a brain MRI study.

You have been provided with multiple MRI sequences: T1, T2, T2 FLAIR, DWI, ADC, GRE, and T1 with contrast.

Write a structured report using EXACTLY this format:

TECHNIQUE:
MRI of the brain was performed on 1.5T MRI using T1 and T2 weighted sequences in multiple planes along with FLAIR, DWI/ADC, GRE and post-contrast T1 images.

FINDINGS:
- Cerebral parenchyma: [signal intensity, any focal or diffuse changes]
- Diffusion: [any restricted diffusion]
- Haemorrhage/Mass: [presence or absence]
- Extra/Intra axial collections: [midline shift, fluid collections]
- Hippocampi: [signal, volume]
- Basal ganglia, thalami, brainstem and cerebellum: [appearance]
- Sellar/Parasellar region: [pituitary, cavernous sinuses]
- Ventricular system and subarachnoid spaces: [appearance]
- Cranial nerves and cerebellopontine angles: [appearance]
- Intracranial vasculature: [flow voids]
- Paranasal sinuses and mastoid air cells: [appearance]
- Orbits: [appearance]
- Calvarium: [marrow signal]

CONCLUSION:
[Clear single line summary,  e.g. 'No abnormality detected' or specific finding]

Rules:
- Never invent clinical history
- If a finding cannot be confidently assessed, say so explicitly
- Be specific about location using standard anatomical terms
- Keep language professional and concise"""

print("Loading MedGemma... this may take a few minutes")

# Processor bundles the tokenizer and the image preprocessor.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Pick device/precision: bfloat16 on GPU, full float32 on CPU
# (bfloat16 CPU inference is slow/poorly supported on many backends).
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.bfloat16
    device = "cuda:0"
else:
    dtype = torch.float32
    device = "cpu"

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
)

# Silence the "pad_token_id not set" warning during generate() by
# reusing the EOS token as padding, then switch to inference mode.
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
model.eval()

print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")

def generate_report(images):
    """
    Generate a structured brain-MRI radiology report with MedGemma.

    Args:
        images: list of PIL Images, one per MRI sequence (T1, T2, FLAIR, ...).
            An empty list is allowed — the model then sees only the text prompt.

    Returns:
        str: the decoded report text (newly generated tokens only).
    """
    # BUG FIX: the parameter was named `image` while the loop below iterated
    # `images`, raising NameError on every call. Parameter renamed to match.
    content = [{"type": "image", "image": img} for img in images]

    # The reporting instructions go last, after all image entries.
    content.append({"type": "text", "text": PROMPT})

    messages = [{"role": "user", "content": content}]

    # Tokenize the multimodal chat prompt; return_dict=True yields a mapping
    # (input_ids, attention_mask, pixel values) rather than a raw tensor,
    # which is what model.generate(**inputs) expects.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding (do_sample=False) keeps reports deterministic.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    report = processor.decode(generated_tokens, skip_special_tokens=True)

    return report