import os
import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText
# Authenticate with your Hugging Face access token (read from the HF_TOKEN environment variable)
login(token=os.environ["HF_TOKEN"])
MODEL_ID = "google/medgemma-4b-it"
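# Note: MedGemma is a gated model on Hugging Face; you must accept its terms
# of use on the model page before the download will succeed.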
PROMPT = """You are a senior consultant radiologist reporting a brain MRI study.
You have been provided with multiple MRI sequences: T1, T2, T2 FLAIR, DWI, ADC, GRE, and T1 with contrast.
Write a structured report using EXACTLY this format:
TECHNIQUE:
MRI of the brain was performed on a 1.5T scanner using T1- and T2-weighted sequences in multiple planes, along with FLAIR, DWI/ADC, GRE and post-contrast T1 images.
FINDINGS:
- Cerebral parenchyma: [signal intensity, any focal or diffuse changes]
- Diffusion: [any restricted diffusion]
- Haemorrhage/Mass: [presence or absence]
- Extra/Intra axial collections: [midline shift, fluid collections]
- Hippocampi: [signal, volume]
- Basal ganglia, thalami, brainstem and cerebellum: [appearance]
- Sellar/Parasellar region: [pituitary, cavernous sinuses]
- Ventricular system and subarachnoid spaces: [appearance]
- Cranial nerves and cerebellopontine angles: [appearance]
- Intracranial vasculature: [flow voids]
- Paranasal sinuses and mastoid air cells: [appearance]
- Orbits: [appearance]
- Calvarium: [marrow signal]
CONCLUSION:
[Clear single line summary, e.g. 'No abnormality detected' or specific finding]
Rules:
- Never invent clinical history
- If a finding cannot be confidently assessed, say so explicitly
- Be specific about location using standard anatomical terms
- Keep language professional and concise"""
print("Loading MedGemma... this may take a few minutes")
processor = AutoProcessor.from_pretrained(MODEL_ID)
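# bfloat16 halves memory use on GPU; fall back to float32 on CPU, where
# bfloat16 inference may be slow or unsupported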
use_cuda = torch.cuda.is_available()
dtype = torch.bfloat16 if use_cuda else torch.float32
device = "cuda:0" if use_cuda else "cpu"
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
)
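# Set the pad token explicitly so generate() does not fall back to the
# eos token with a warning on every call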
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
model.eval()
print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")
def generate_report(images):
    """
    Takes a list of PIL Images (one per MRI sequence) and
    returns a structured radiology report.
    """
    content = []
    for img in images:
        content.append({"type": "image", "image": img})
    content.append({"type": "text", "text": PROMPT})
    messages = [{"role": "user", "content": content}]

    # Prepare inputs
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,  # returns a dict, not a raw tensor
        return_tensors="pt",
    ).to(model.device)

    # Generate the report deterministically (greedy decoding)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    report = processor.decode(generated_tokens, skip_special_tokens=True)
    return report
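
# --- Example usage: a minimal, illustrative sketch. The file names below are
# hypothetical placeholders; point them at your own exported MRI slices. ---
if __name__ == "__main__":
    from PIL import Image

    sequence_paths = [
        "t1.png", "t2.png", "flair.png", "dwi.png",
        "adc.png", "gre.png", "t1_contrast.png",
    ]
    mri_images = [Image.open(path).convert("RGB") for path in sequence_paths]
    print(generate_report(mri_images))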