| """ |
| Model inference logic for XRD pattern analysis. |
| Loads the pretrained model from HuggingFace Hub and runs predictions. |
| """ |
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Optional |
|
|
| import numpy as np |
| import spglib |
| import torch |
|
|
|
|
class XRDModelInference:
    """Load the AlphaDiffract model and run inference on XRD patterns.

    The model code and weights are fetched from the Hugging Face Hub repository
    named by ``HF_REPO_ID``. Until ``load_model`` succeeds, ``predict`` returns
    an error dictionary instead of raising.
    """

    # Hugging Face Hub repository that hosts the model code and weights.
    HF_REPO_ID = "linked-liszt/OpenAlphaDiffract"

    # Labels for the crystal-system head; position i names logit i.
    CRYSTAL_SYSTEMS = (
        "Triclinic", "Monoclinic", "Orthorhombic", "Tetragonal",
        "Trigonal", "Hexagonal", "Cubic",
    )

    # Labels paired, in order, with the values of the lattice-parameter head.
    LATTICE_LABELS = ("a", "b", "c", "\u03b1", "\u03b2", "\u03b3")

    # Maximum number of space-group candidates reported per prediction.
    _TOP_K_SPACE_GROUPS = 10

    # Space-group number -> first Hall number that maps to it. Filled lazily by
    # ``_hall_number_for`` so that defining the class performs no spglib calls
    # and leaves no loop variables behind as class attributes.
    _sg_to_hall: Dict[int, int] = {}

    def __init__(self):
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def is_loaded(self) -> bool:
        """Return True once a model has been loaded successfully."""
        return self.model is not None

    def load_model(self):
        """Download and load the pretrained model from HuggingFace Hub.

        Never raises: on any failure the error and traceback are printed and
        ``self.model`` is reset to ``None`` (check with ``is_loaded``).
        """
        try:
            from huggingface_hub import snapshot_download

            print(f"Downloading model from {self.HF_REPO_ID}...")
            model_dir = snapshot_download(self.HF_REPO_ID)
            print(f"Model downloaded to {model_dir}")

            # The snapshot ships its own ``model`` module; make it importable.
            # Guarded so repeated loads do not pile up duplicate path entries.
            if model_dir not in sys.path:
                sys.path.insert(0, model_dir)
            from model import AlphaDiffract

            model = AlphaDiffract.from_pretrained(
                model_dir, device=str(self.device)
            )
            # NOTE(review): inference should run in eval mode. This is a no-op
            # if from_pretrained already switched modes -- confirm upstream.
            if isinstance(model, torch.nn.Module):
                model.eval()
            self.model = model

            print(f"Model loaded successfully on {self.device}")

        except Exception as e:
            # Best-effort boundary: report the failure and leave the instance
            # in the "not loaded" state rather than propagating.
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            self.model = None

    def preprocess_data(self, x: List[float], y: List[float]) -> torch.Tensor:
        """
        Preprocess XRD data for model input.

        Negative intensities are clipped to zero, then the pattern is min-max
        scaled to the range [0, 100]. A flat pattern becomes all zeros.

        Args:
            x: 2theta values (accepted for interface symmetry; not used here)
            y: Intensity values

        Returns:
            float32 tensor on ``self.device`` with a leading batch axis of 1

        Raises:
            ValueError: if ``y`` is empty or contains NaN/infinite values
        """
        intensities = np.asarray(y, dtype=np.float32)
        if intensities.size == 0:
            raise ValueError("Intensity data is empty.")
        # Non-finite values would turn the whole normalised pattern into NaN.
        if not np.all(np.isfinite(intensities)):
            raise ValueError("Intensity data contains NaN or infinite values.")

        intensities = np.maximum(intensities, 0.0)
        low = np.min(intensities)
        span = np.max(intensities) - low

        if span < 1e-10:
            # Flat signal: avoid dividing by (almost) zero.
            scaled = np.zeros_like(intensities, dtype=np.float32)
        else:
            scaled = (intensities - low) / span * 100.0

        return torch.from_numpy(scaled).unsqueeze(0).to(self.device)

    def predict(self, x: List[float], y: List[float]) -> Dict:
        """
        Run inference on XRD data.

        Args:
            x: 2theta values
            y: Intensity values

        Returns:
            On success, a dictionary with ``status`` "success", the processed
            ``predictions``, ``model_info`` and, when available, an overall
            ``confidence``. On failure, a dictionary with ``status`` "error",
            an ``error`` message and an ``http_status`` (503 when no model is
            loaded, 500 for any error raised during inference).
        """
        if self.model is None:
            return {
                "status": "error",
                "error": "Model not loaded.",
                "http_status": 503,
            }

        try:
            input_tensor = self.preprocess_data(x, y)

            with torch.no_grad():
                output = self.model(input_tensor)

            processed = self._process_model_output(output)
            overall_confidence = self._compute_overall_confidence(processed)

            predictions = {
                "status": "success",
                "predictions": processed,
                "model_info": {
                    "type": "AlphaDiffract",
                    "device": str(self.device),
                },
            }
            if overall_confidence is not None:
                predictions["confidence"] = overall_confidence

            return predictions

        except Exception as e:
            # Top-level boundary: callers get an error payload, never a raise.
            return {
                "status": "error",
                "error": str(e),
                "http_status": 500,
            }

    def _process_model_output(self, output) -> Dict:
        """Process raw model output into readable predictions.

        A dict output may carry ``cs_logits``, ``sg_logits`` and ``lp``; each
        key that is present adds one entry to ``phase_predictions``. A bare
        tensor output is summarised generically. Any other type yields an
        empty result. Only the first sample of a batch is reported.
        """
        if isinstance(output, dict):
            predictions = []
            if "cs_logits" in output:
                predictions.append(
                    self._crystal_system_prediction(output["cs_logits"])
                )
            if "sg_logits" in output:
                predictions.append(
                    self._space_group_prediction(output["sg_logits"])
                )
            if "lp" in output:
                predictions.append(self._lattice_prediction(output["lp"]))
            return {
                "phase_predictions": predictions,
                "intensity_profile": [],
            }

        if isinstance(output, torch.Tensor):
            # detach() so tensors that still track gradients can be converted.
            raw = output.detach()
            values = raw.cpu().numpy()
            summary = {
                "phase": "Predicted Phase",
                "details": f"Output shape: {values.shape}",
            }
            if raw.ndim >= 1 and raw.shape[-1] > 1:
                summary["confidence"] = float(
                    torch.softmax(raw, dim=-1).max().item()
                )
            return {
                "phase_predictions": [summary],
                "intensity_profile": values.tolist() if values.ndim == 1 else [],
            }

        return {"phase_predictions": [], "intensity_profile": []}

    @staticmethod
    def _first_sample_probs(logits: torch.Tensor) -> torch.Tensor:
        """Return softmax probabilities of the first sample as a 1-D CPU tensor."""
        probs = torch.softmax(logits.detach().cpu(), dim=-1)
        # Collapse any leading axes so both (C,) and (1, C) inputs work.
        return probs.reshape(-1, probs.shape[-1])[0]

    def _crystal_system_prediction(self, logits: torch.Tensor) -> Dict:
        """Build the "Crystal System" entry from the crystal-system logits."""
        probs = self._first_sample_probs(logits)
        best_prob, best_idx = torch.max(probs, dim=-1)

        ranked = [
            {"class_name": name, "probability": float(probs[i].item())}
            for i, name in enumerate(self.CRYSTAL_SYSTEMS)
        ]
        ranked.sort(key=lambda entry: entry["probability"], reverse=True)

        return {
            "phase": "Crystal System",
            "predicted_class": self.CRYSTAL_SYSTEMS[int(best_idx.item())],
            "confidence": float(best_prob.item()),
            "all_probabilities": ranked,
        }

    def _space_group_prediction(self, logits: torch.Tensor) -> Dict:
        """Build the "Space Group" entry, including the top-k candidates."""
        probs = self._first_sample_probs(logits)
        best_prob, best_idx = torch.max(probs, dim=-1)
        # Logit index i stands for space group number i + 1.
        sg_number = int(best_idx.item()) + 1

        top_k = min(self._TOP_K_SPACE_GROUPS, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, top_k)

        candidates = []
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            number = int(idx) + 1
            candidates.append({
                "space_group_number": number,
                "space_group_symbol": self._get_space_group_symbol(number),
                "probability": float(prob),
            })

        return {
            "phase": "Space Group",
            "predicted_class": f"#{sg_number}",
            "space_group_symbol": self._get_space_group_symbol(sg_number),
            "confidence": float(best_prob.item()),
            "top_probabilities": candidates,
        }

    def _lattice_prediction(self, lp_raw: torch.Tensor) -> Dict:
        """Build the "Lattice Parameters" entry from the ``lp`` output."""
        lp = lp_raw.detach().cpu()
        # Use the first sample when a batch axis is present.
        values = (lp[0] if lp.ndim > 1 else lp).reshape(-1).tolist()

        return {
            "phase": "Lattice Parameters",
            "lattice_params": {
                label: float(val)
                for label, val in zip(self.LATTICE_LABELS, values)
            },
            "is_lattice": True,
        }

    @staticmethod
    def _sg_field(sg_type, field: str):
        """Read ``field`` from a spglib space-group-type result.

        Supports both attribute-style and mapping-style results (presumably
        these differ between spglib versions).
        """
        if hasattr(sg_type, field):
            return getattr(sg_type, field)
        return sg_type[field]

    @classmethod
    def _hall_number_for(cls, sg_number: int) -> Optional[int]:
        """Return the first Hall number (1-530) belonging to ``sg_number``.

        The lookup table is built on first use and cached on the class.
        """
        if not cls._sg_to_hall:
            table: Dict[int, int] = {}
            for hall in range(1, 531):
                number = cls._sg_field(spglib.get_spacegroup_type(hall), "number")
                # Keep only the first Hall number seen for each space group.
                table.setdefault(number, hall)
            # Assigned in one step so a partially built table is never visible.
            cls._sg_to_hall = table
        return cls._sg_to_hall.get(sg_number)

    def _get_space_group_symbol(self, sg_number: int) -> str:
        """Get space group symbol from number using spglib.

        Returns the short international symbol, or ``"SG<number>"`` when the
        number is outside 1-230 or the lookup fails for any reason.
        """
        fallback = f"SG{sg_number}"
        if not 1 <= sg_number <= 230:
            return fallback
        try:
            hall_number = self._hall_number_for(sg_number)
            if hall_number is None:
                return fallback
            sg_type = spglib.get_spacegroup_type(hall_number)
            if sg_type is None:
                return fallback
            return self._sg_field(sg_type, "international_short")
        except Exception:
            # Deliberately best-effort: a missing symbol must not fail a
            # prediction, so fall back to the numeric label.
            return fallback

    def _compute_overall_confidence(self, processed: Dict) -> Optional[float]:
        """Compute overall confidence from available per-phase confidences.

        Returns the mean of every non-None ``confidence`` found in
        ``processed["phase_predictions"]``, or ``None`` when there is none.
        """
        if not isinstance(processed, dict):
            return None
        confidences = [
            float(entry["confidence"])
            for entry in processed.get("phase_predictions", [])
            if isinstance(entry, dict) and entry.get("confidence") is not None
        ]
        if not confidences:
            return None
        return float(np.mean(confidences))
|
|