| import os |
| import torch |
| from huggingface_hub import login |
| from transformers import AutoProcessor, AutoModelForImageTextToText |
|
|
| |
# Authenticate with the Hugging Face Hub. Fail fast with an actionable
# message instead of the opaque KeyError that os.environ["HF_TOKEN"] raises
# when the variable is unset.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; it is required to "
        "download the gated MedGemma weights."
    )
login(token=hf_token)


# Gated model id — the license must be accepted on the Hub before download.
MODEL_ID = "google/medgemma-1.5-4b-it"


print("Loading MedGemma... this may take a few minutes")


processor = AutoProcessor.from_pretrained(MODEL_ID)


# bfloat16 halves memory on CUDA GPUs; CPUs fall back to float32 since
# bfloat16 CPU support is inconsistent across torch builds.
use_cuda = torch.cuda.is_available()
dtype = torch.bfloat16 if use_cuda else torch.float32
device = "cuda:0" if use_cuda else "cpu"


model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device
)


# Use EOS as the pad token so generate() does not warn about a missing
# pad_token_id; put the model in eval mode (disables dropout etc.).
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
model.eval()


print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")
|
|
def generate_report(image, max_new_tokens=512):
    """Generate a structured radiology report for a brain MRI scan.

    Args:
        image: A PIL Image of the scan; passed straight through to the
            processor's chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 512, preserving the original behavior.

    Returns:
        str: The generated report, with each section header
        (TECHNIQUE / FINDINGS / IMPRESSION) pushed onto its own line.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": """You are an experienced radiologist.
Analyze this brain MRI scan and generate a structured clinical report with the following sections:

TECHNIQUE:
FINDINGS:
IMPRESSION:

Be specific, professional and concise."""}
            ]
        }
    ]

    # Tokenize the chat-formatted prompt and move the tensors to whichever
    # device the model was loaded on.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Greedy decoding (do_sample=False) for reproducible reports.
    # inference_mode is the preferred inference context: same results as
    # no_grad, slightly cheaper (also skips view/version tracking).
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

    # generate() returns prompt + completion; slice off the prompt tokens
    # so only newly generated text is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    report = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Put each section header on its own line for readability.
    for header in ("TECHNIQUE:", "FINDINGS:", "IMPRESSION:"):
        report = report.replace(header, "\n" + header)

    return report