import os

import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText

# Authenticate with your Hugging Face token (required for gated models)
login(token=os.environ["HF_TOKEN"])

MODEL_ID = "google/medgemma-1.5-4b-it"

print("Loading MedGemma... this may take a few minutes")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Use bfloat16 on GPU; fall back to float32 on CPU
use_cuda = torch.cuda.is_available()
dtype = torch.bfloat16 if use_cuda else torch.float32
device = "cuda:0" if use_cuda else "cpu"

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
)

# Silence the "pad_token_id not set" warning during generation
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
model.eval()

print(f"MedGemma loaded successfully on {device}")


def generate_report(image):
    """
    Takes a PIL Image and returns a generated radiology report string.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": """You are an experienced radiologist. Analyze this brain MRI scan and generate a structured clinical report with the following sections:

TECHNIQUE:
FINDINGS:
IMPRESSION:

Be specific, professional, and concise."""},
            ],
        }
    ]

    # Prepare inputs; .to() moves tensors to the model device and casts
    # floating-point tensors (the pixel values) to the model dtype
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,  # ← returns a dict, not a raw Tensor
        return_tensors="pt",
    ).to(model.device, dtype=dtype)

    # Generate the report deterministically (greedy decoding)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # Decode only the newly generated tokens, not the echoed prompt
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    report = processor.decode(generated_tokens, skip_special_tokens=True)

    # Put each section header on its own line
    for header in ("TECHNIQUE:", "FINDINGS:", "IMPRESSION:"):
        report = report.replace(header, f"\n{header}")

    return report
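

# Usage sketch: a quick smoke test of generate_report. The file path
# "sample_brain_mri.png" is illustrative, not part of the original;
# substitute any brain MRI slice readable by PIL.
if __name__ == "__main__":
    from PIL import Image

    scan = Image.open("sample_brain_mri.png").convert("RGB")
    print(generate_report(scan))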