# DeepSeeNet / run_inference.py
# Uploaded by farrell236 in commit b8c9192 ("add src").
"""Run DeepSeeNet inference for AREDS simplified score."""
import argparse
import json
import torch
from PIL import Image
from dataloader import DEFAULT_TRANSFORM
from model import DeepSeeNet
# Number of output classes for each per-task DeepSeeNet classifier; the value
# is passed as ``n_classes`` when the task's model is built. The keys are also
# the task names used to pick checkpoints and to label the "risk_factors"
# entries in the printed JSON.
# NOTE(review): presumably ADVAMD and PIG are absent/present and DRUS is a
# three-level grade (simplified_score gives DRUS 1 and 2 special meaning) —
# confirm against the training label definitions.
N_CLASSES = dict(ADVAMD=2, DRUS=3, PIG=2)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for one inference run.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so callers that
            pass nothing behave exactly as before.

    Returns:
        Namespace with ``left_image``, ``right_image``, ``advamd_checkpoint``,
        ``drus_checkpoint``, ``pig_checkpoint`` and ``backbone`` attributes.

    Raises:
        SystemExit: if a required option is missing (standard argparse exit).
    """
    parser = argparse.ArgumentParser(
        description="Run DeepSeeNet inference for AREDS simplified score."
    )
    parser.add_argument(
        "--left-image", required=True, help="Path to the left-eye image."
    )
    parser.add_argument(
        "--right-image", required=True, help="Path to the right-eye image."
    )
    parser.add_argument(
        "--advamd-checkpoint",
        required=True,
        help="Checkpoint file for the ADVAMD classifier.",
    )
    parser.add_argument(
        "--drus-checkpoint",
        required=True,
        help="Checkpoint file for the DRUS classifier.",
    )
    parser.add_argument(
        "--pig-checkpoint",
        required=True,
        help="Checkpoint file for the PIG classifier.",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help=(
            "Fallback backbone name; a backbone recorded in a checkpoint's "
            "saved args takes precedence (default: inception_v3)."
        ),
    )
    return parser.parse_args(argv)
def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
    """Build the DeepSeeNet classifier for *task* and restore its weights.

    Args:
        checkpoint_path: File readable by ``torch.load``; it must contain a
            ``"model"`` entry (the state dict) and may contain an ``"args"``
            mapping.
        task: Key into ``N_CLASSES`` that selects the number of output classes.
        backbone: Backbone name used only when the checkpoint's ``"args"``
            has no ``"backbone"`` entry.
        device: Target device for both the loaded tensors and the model.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only pass checkpoints from a trusted source.
    state = torch.load(checkpoint_path, map_location=device)
    saved_args = state.get("args", {})
    num_classes = N_CLASSES[task]
    # The backbone stored with the checkpoint wins over the CLI value.
    arch = saved_args.get("backbone", backbone)
    # pretrained=False: every weight comes from the checkpoint below.
    net = DeepSeeNet(
        n_classes=num_classes,
        backbone=arch,
        pretrained=False,
    )
    net = net.to(device)
    net.load_state_dict(state["model"])
    net.eval()
    return net
def load_image(path: str, device) -> torch.Tensor:
    """Read *path* as an RGB image and return it as a batch of one on *device*.

    The source file is opened in a ``with`` block so its handle is always
    released. ``convert("RGB")`` produces an independent, fully loaded copy,
    so the pixels stay usable after the file is closed.

    Args:
        path: Filesystem path of the image to read.
        device: Device the resulting tensor is moved to.

    Returns:
        ``DEFAULT_TRANSFORM(image)`` with a new leading batch dimension of
        size 1. The exact shape and normalization depend on
        ``DEFAULT_TRANSFORM`` (defined in ``dataloader``, not shown here).
    """
    with Image.open(path) as raw:
        rgb = raw.convert("RGB")
    return DEFAULT_TRANSFORM(rgb).unsqueeze(0).to(device)
@torch.no_grad()
def predict(model: DeepSeeNet, image: torch.Tensor) -> int:
    """Return the index of the highest-scoring class for a single image.

    Runs without autograd tracking. ``argmax`` is taken over dimension 1, and
    ``.item()`` requires exactly one resulting value, so *image* must be a
    batch of one (as produced by ``load_image``).
    NOTE(review): assumes the model output is (batch, n_classes) scores —
    confirm against DeepSeeNet's forward().
    """
    scores = model(image)
    best_class = scores.argmax(dim=1)
    return int(best_class.item())
def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
    """Combine per-eye risk-factor predictions into a single 0-5 score.

    Args:
        scores: Maps ``"ADVAMD"``, ``"DRUS"`` and ``"PIG"`` to a
            ``(left, right)`` pair of predicted class indices.

    Returns:
        5 if the ADVAMD prediction is non-zero for either eye. Otherwise one
        point per eye with PIG == 1, one point per eye with DRUS == 2, and one
        extra point when DRUS == 1 in both eyes, capped at 5. (The cap is never
        reached in that branch: at most 4 points are possible.)

    NOTE(review): presumably DRUS 2 = large and DRUS 1 = intermediate drusen,
    following the AREDS simplified severity scale — confirm against the
    label definitions.
    """
    # Checked first, and lazily, so PIG/DRUS are never read when this fires.
    if any(scores["ADVAMD"][eye] for eye in (0, 1)):
        return 5

    pig = scores["PIG"]
    drus = scores["DRUS"]
    points = sum(pig[eye] == 1 for eye in (0, 1))
    points += sum(drus[eye] == 2 for eye in (0, 1))
    # Bilateral DRUS == 1 earns one extra point.
    points += drus[0] == 1 and drus[1] == 1
    return min(points, 5)
def main() -> None:
    """Grade a left/right image pair and print the results as JSON.

    Loads both images, then loads the ADVAMD, DRUS and PIG models one after
    another, scoring both eyes with each. Prints an indented JSON object with
    ``"simplified_score"`` and ``"risk_factors"``; each risk factor's
    ``(left, right)`` tuple is serialized as a two-element list.
    """
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    left = load_image(args.left_image, device)
    right = load_image(args.right_image, device)

    task_checkpoints = (
        ("ADVAMD", args.advamd_checkpoint),
        ("DRUS", args.drus_checkpoint),
        ("PIG", args.pig_checkpoint),
    )
    scores = {}
    for task, ckpt_path in task_checkpoints:
        net = load_model(ckpt_path, task, args.backbone, device)
        scores[task] = (predict(net, left), predict(net, right))

    report = {
        "simplified_score": simplified_score(scores),
        "risk_factors": scores,
    }
    print(json.dumps(report, indent=2))
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()