import os
import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText
# Authenticate with the Hugging Face Hub (MedGemma is gated, so HF_TOKEN must have access)
login(token=os.environ["HF_TOKEN"])
MODEL_ID = "google/medgemma-1.5-4b-it"
print("Loading MedGemma... this may take a few minutes")
processor = AutoProcessor.from_pretrained(MODEL_ID)
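# Prefer a GPU with bfloat16 when available; otherwise fall back to CPU in float32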
use_cuda = torch.cuda.is_available()
dtype = torch.bfloat16 if use_cuda else torch.float32
device = "cuda:0" if use_cuda else "cpu"
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
)
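# Pad with the EOS token so generate() does not warn about a missing pad token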
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
model.eval()
print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")
def generate_report(image):
"""
Takes a PIL Image, returns a generated radiology report string.
"""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": (
                    "You are an experienced radiologist.\n"
                    "Analyze this brain MRI scan and generate a structured "
                    "clinical report with the following sections:\n"
                    "TECHNIQUE:\n"
                    "FINDINGS:\n"
                    "IMPRESSION:\n"
                    "Be specific, professional and concise."
                )},
            ],
        }
    ]
    # Prepare inputs: tokenize the chat template into input_ids, pixel_values, etc.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,  # returns a dict of tensors, not a raw Tensor
        return_tensors="pt",
    ).to(model.device, dtype=dtype)  # cast float inputs (pixel_values) to the model dtype
    # Generate the report (greedy decoding, no sampling)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )
    # Decode only the newly generated tokens, skipping the prompt
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    report = processor.decode(generated_tokens, skip_special_tokens=True)
    # Clean up the report spacing: put each section heading on its own line
    for heading in ("TECHNIQUE:", "FINDINGS:", "IMPRESSION:"):
        report = report.replace(heading, f"\n{heading}")
    return report
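# Minimal smoke test, assuming a local sample image; "scan.png" is a
# placeholder path, not part of the deployed app.
if __name__ == "__main__":
    from PIL import Image

    sample = Image.open("scan.png").convert("RGB")
    print(generate_report(sample))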