# OpenAlphaDiffract-UI / app/model_inference.py
# Uploaded by linked-liszt (commit 6d08d46, verified):
# "Upload folder using huggingface_hub"
"""
Model inference logic for XRD pattern analysis.
Loads the pretrained model from HuggingFace Hub and runs predictions.
"""
import sys
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import spglib
import torch
class XRDModelInference:
    """Load the pretrained AlphaDiffract model and run inference on XRD patterns.

    The methods are written so that failures are reported through return
    values rather than raised: ``load_model`` leaves ``self.model`` as ``None``
    on error (check with ``is_loaded``), and ``predict`` returns a dict with
    ``"status": "error"`` and an ``"http_status"`` code.
    """

    HF_REPO_ID = "linked-liszt/OpenAlphaDiffract"

    # Name reported for index ``i`` of the ``cs_logits`` output.
    CRYSTAL_SYSTEM_NAMES = (
        "Triclinic", "Monoclinic", "Orthorhombic", "Tetragonal",
        "Trigonal", "Hexagonal", "Cubic",
    )
    # Label reported for position ``i`` of the ``lp`` output.
    LATTICE_LABELS = ("a", "b", "c", "\u03b1", "\u03b2", "\u03b3")

    _HALL_NUMBER_COUNT = 530
    _SPACE_GROUP_COUNT = 230
    _TOP_K_SPACE_GROUPS = 10

    # Lookup table mapping space group number (1-230) to a Hall number.
    # spglib.get_spacegroup_type() is indexed by Hall number (1-530), NOT by
    # space group number, so the first Hall number seen for each space group
    # is kept.  The table is filled lazily by _hall_number_table() so that
    # defining the class performs no spglib calls and leaks no loop variables
    # into the class namespace.
    _sg_to_hall: Dict[int, int] = {}

    def __init__(self):
        """Start with no model loaded; use CUDA when torch reports it available."""
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def is_loaded(self) -> bool:
        """Return True once ``load_model`` has successfully set ``self.model``."""
        return self.model is not None

    def load_model(self):
        """Download and load the pretrained model from HuggingFace Hub.

        Any exception is printed (with traceback) and swallowed; in that case
        ``self.model`` is reset to ``None``.
        """
        try:
            from huggingface_hub import snapshot_download

            print(f"Downloading model from {self.HF_REPO_ID}...")
            model_dir = snapshot_download(self.HF_REPO_ID)
            print(f"Model downloaded to {model_dir}")

            # Import the pure-PyTorch model class from the downloaded repo.
            # NOTE(security): this executes Python code shipped in that repo,
            # so HF_REPO_ID must only ever point at a trusted repository.
            # Guarded so repeated calls do not pile up duplicate path entries.
            if model_dir not in sys.path:
                sys.path.insert(0, model_dir)
            from model import AlphaDiffract

            self.model = AlphaDiffract.from_pretrained(
                model_dir, device=str(self.device)
            )
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback

            traceback.print_exc()
            self.model = None

    def preprocess_data(self, x: List[float], y: List[float]) -> torch.Tensor:
        """Turn raw intensities into the tensor fed to the model.

        Negative intensities are floored at zero, then the pattern is min-max
        rescaled to [0, 100] (the original author notes this matches the
        training preprocessing).  A constant pattern becomes all zeros.

        Args:
            x: 2theta values. Currently unused; kept for interface stability.
            y: Intensity values.

        Returns:
            A float32 tensor of shape ``(1, len(y))`` on ``self.device``.

        Raises:
            ValueError: If ``y`` is empty or contains NaN / infinite values,
                either of which would otherwise yield an opaque NumPy error or
                an all-NaN model input.
        """
        y_array = np.asarray(y, dtype=np.float32)
        if y_array.size == 0:
            raise ValueError("Intensity data is empty.")
        if not np.all(np.isfinite(y_array)):
            raise ValueError("Intensity data contains NaN or infinite values.")

        # Floor at zero (remove any negative values).
        y_array = np.maximum(y_array, 0.0)

        y_min = np.min(y_array)
        y_max = np.max(y_array)
        span = y_max - y_min
        if span < 1e-10:
            # Flat pattern: avoid dividing by (almost) zero.
            y_scaled = np.zeros_like(y_array, dtype=np.float32)
        else:
            y_scaled = (y_array - y_min) / span * 100.0

        # Add the leading batch dimension.
        tensor = torch.from_numpy(y_scaled).unsqueeze(0)
        return tensor.to(self.device)

    def predict(self, x: List[float], y: List[float]) -> Dict:
        """Run inference on XRD data.

        Args:
            x: 2theta values.
            y: Intensity values.

        Returns:
            On success, a dict with ``"status": "success"``, the processed
            ``"predictions"``, ``"model_info"`` and, when any per-phase
            confidence exists, an overall ``"confidence"``.  On failure, a dict
            with ``"status": "error"``, an ``"error"`` message and an
            ``"http_status"`` (503 when no model is loaded, 500 otherwise).
        """
        if self.model is None:
            return {
                "status": "error",
                "error": "Model not loaded.",
                "http_status": 503,
            }

        try:
            input_tensor = self.preprocess_data(x, y)
            with torch.no_grad():
                output = self.model(input_tensor)

            processed = self._process_model_output(output)
            overall_confidence = self._compute_overall_confidence(processed)

            predictions = {
                "status": "success",
                "predictions": processed,
                "model_info": {
                    "type": "AlphaDiffract",
                    "device": str(self.device),
                },
            }
            if overall_confidence is not None:
                predictions["confidence"] = overall_confidence
            return predictions
        except Exception as e:
            # Boundary handler: every failure becomes an error payload.
            return {
                "status": "error",
                "error": str(e),
                "http_status": 500,
            }

    @staticmethod
    def _to_cpu(tensor: torch.Tensor) -> torch.Tensor:
        """Return a detached CPU copy, widening half precision to float32."""
        tensor = tensor.detach().cpu()
        # Defensive: half-precision tensors are poorly supported by CPU
        # softmax / numpy conversion, so widen them before further use.
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        return tensor

    @classmethod
    def _first_row(cls, tensor: torch.Tensor) -> torch.Tensor:
        """Return the first sample of ``tensor`` as a 1-D CPU tensor."""
        tensor = cls._to_cpu(tensor)
        return tensor.reshape(-1, tensor.shape[-1])[0]

    def _crystal_system_prediction(self, logits: torch.Tensor) -> Dict:
        """Build the "Crystal System" entry from the ``cs_logits`` output."""
        probs = torch.softmax(self._first_row(logits), dim=-1)
        best_prob, best_idx = torch.max(probs, dim=-1)
        all_probs = [
            {"class_name": name, "probability": float(prob)}
            for name, prob in zip(self.CRYSTAL_SYSTEM_NAMES, probs.tolist())
        ]
        all_probs.sort(key=lambda entry: entry["probability"], reverse=True)
        return {
            "phase": "Crystal System",
            "predicted_class": self.CRYSTAL_SYSTEM_NAMES[int(best_idx.item())],
            "confidence": float(best_prob.item()),
            "all_probabilities": all_probs,
        }

    def _space_group_prediction(self, logits: torch.Tensor) -> Dict:
        """Build the "Space Group" entry from the ``sg_logits`` output."""
        probs = torch.softmax(self._first_row(logits), dim=-1)
        best_prob, best_idx = torch.max(probs, dim=-1)
        # Class index i corresponds to space group number i + 1.
        sg_number = int(best_idx.item()) + 1

        top_k = min(self._TOP_K_SPACE_GROUPS, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, top_k)
        top_entries = []
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            number = int(idx) + 1
            top_entries.append({
                "space_group_number": number,
                "space_group_symbol": self._get_space_group_symbol(number),
                "probability": float(prob),
            })

        return {
            "phase": "Space Group",
            "predicted_class": f"#{sg_number}",
            "space_group_symbol": self._get_space_group_symbol(sg_number),
            "confidence": float(best_prob.item()),
            "top_probabilities": top_entries,
        }

    def _lattice_prediction(self, lp: torch.Tensor) -> Dict:
        """Build the "Lattice Parameters" entry from the ``lp`` output."""
        values = self._first_row(lp).tolist()
        return {
            "phase": "Lattice Parameters",
            "lattice_params": {
                label: float(val)
                for label, val in zip(self.LATTICE_LABELS, values)
            },
            "is_lattice": True,
        }

    def _process_model_output(self, output) -> Dict:
        """Process raw model output into readable predictions.

        Args:
            output: Either a dict that may contain ``"cs_logits"``,
                ``"sg_logits"`` and/or ``"lp"`` tensors, or a bare tensor.
                When a tensor carries several samples only the first one is
                reported.

        Returns:
            A dict with ``"phase_predictions"`` (list of per-head dicts) and
            ``"intensity_profile"``.  Unrecognised output types produce both
            as empty lists.
        """
        if isinstance(output, dict):
            predictions = []
            # Crystal System prediction (7 classes)
            if "cs_logits" in output:
                predictions.append(
                    self._crystal_system_prediction(output["cs_logits"])
                )
            # Space Group prediction (230 classes)
            if "sg_logits" in output:
                predictions.append(
                    self._space_group_prediction(output["sg_logits"])
                )
            # Lattice Parameters
            if "lp" in output:
                predictions.append(self._lattice_prediction(output["lp"]))
            return {
                "phase_predictions": predictions,
                "intensity_profile": [],
            }

        if isinstance(output, torch.Tensor):
            cpu_output = self._to_cpu(output)
            probs = cpu_output.numpy()
            entry = {
                "phase": "Predicted Phase",
                "details": f"Output shape: {probs.shape}",
            }
            # A softmax confidence only makes sense with more than one class.
            if cpu_output.ndim >= 1 and cpu_output.shape[-1] > 1:
                prob_tensor = torch.softmax(cpu_output, dim=-1)
                entry["confidence"] = float(prob_tensor.max().item())
            return {
                "phase_predictions": [entry],
                "intensity_profile": probs.tolist() if probs.ndim == 1 else [],
            }

        return {"phase_predictions": [], "intensity_profile": []}

    @staticmethod
    def _sg_field(sg_type, name: str):
        """Read ``name`` from a spglib space-group record (attribute or key).

        Both access styles are supported because the record may expose its
        fields as attributes or as mapping keys.
        """
        return getattr(sg_type, name) if hasattr(sg_type, name) else sg_type[name]

    @classmethod
    def _hall_number_table(cls) -> Dict[int, int]:
        """Return the space-group -> Hall-number table, building it on first use."""
        if not cls._sg_to_hall:
            # Build into a local dict and publish it in one assignment so a
            # partially filled table is never observable.
            table: Dict[int, int] = {}
            for hall in range(1, cls._HALL_NUMBER_COUNT + 1):
                sg_type = spglib.get_spacegroup_type(hall)
                if sg_type is None:
                    continue
                # setdefault keeps the first Hall number per space group.
                table.setdefault(cls._sg_field(sg_type, "number"), hall)
            cls._sg_to_hall = table
        return cls._sg_to_hall

    def _get_space_group_symbol(self, sg_number: int) -> str:
        """Get space group symbol from number using spglib.

        Args:
            sg_number: Space group number, expected in 1-230.

        Returns:
            The short international symbol, or the placeholder
            ``"SG<number>"`` when the number is out of range or the lookup
            fails for any reason.
        """
        fallback = f"SG{sg_number}"
        if sg_number < 1 or sg_number > self._SPACE_GROUP_COUNT:
            return fallback
        try:
            hall_number = self._hall_number_table().get(sg_number)
            if hall_number is None:
                return fallback
            sg_type = spglib.get_spacegroup_type(hall_number)
            if sg_type is None:
                return fallback
            return self._sg_field(sg_type, "international_short")
        except Exception:
            # Deliberate best effort: the symbol is display-only, so a lookup
            # problem must not fail the whole prediction.
            return fallback

    def _compute_overall_confidence(self, processed: Dict) -> Optional[float]:
        """Compute overall confidence from available per-phase confidences.

        Args:
            processed: Result of ``_process_model_output``; anything that is
                not a dict is treated as having no predictions.

        Returns:
            The arithmetic mean of every non-``None`` ``"confidence"`` found in
            ``processed["phase_predictions"]``, or ``None`` if there are none.
        """
        if not isinstance(processed, dict):
            return None
        confidences = [
            float(entry["confidence"])
            for entry in processed.get("phase_predictions", [])
            if isinstance(entry, dict) and entry.get("confidence") is not None
        ]
        if not confidences:
            return None
        return float(np.mean(confidences))