| """ |
| LookingGlass Classifiers - Fine-tuned DNA sequence classifiers |
| |
| Pure PyTorch implementation of LookingGlass classifiers from the paper. |
| Uses LookingGlass encoder with classification head. |
| |
| Usage: |
| from lookingglass_classifier import LookingGlassClassifier, LookingGlassTokenizer |
| |
| model = LookingGlassClassifier.from_pretrained('.') |
| tokenizer = LookingGlassTokenizer() |
| |
| inputs = tokenizer(["GATTACA"], return_tensors=True) |
| logits = model(inputs['input_ids']) # (batch, num_classes) |
| predictions = logits.argmax(dim=-1) |
| """ |
|
|
| import json |
| import os |
| from dataclasses import dataclass, asdict, field |
| from typing import Optional, List |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from lookingglass import ( |
| LookingGlassConfig, |
| LookingGlassTokenizer, |
| _AWDLSTMEncoder, |
| _is_hf_hub_id, |
| _download_from_hub, |
| ) |
|
|
| __version__ = "1.1.0" |
| __all__ = ["LookingGlassClassifierConfig", "LookingGlassClassifier", "LookingGlassTokenizer"] |
|
|
|
|
@dataclass
class LookingGlassClassifierConfig(LookingGlassConfig):
    """Configuration for a LookingGlass classification model.

    Extends the base ``LookingGlassConfig`` with the hyperparameters of
    the classification head.
    """
    num_classes: int = 2              # number of output classes
    classifier_hidden: int = 50       # width of the hidden classifier layer
    classifier_dropout: float = 0.0   # dropout used inside the classifier head
    class_names: List[str] = field(default_factory=list)  # optional human-readable labels

    def save_pretrained(self, save_directory: str):
        """Write this config to ``<save_directory>/config.json``."""
        os.makedirs(save_directory, exist_ok=True)
        config_file = os.path.join(save_directory, "config.json")
        with open(config_file, 'w') as handle:
            json.dump(self.to_dict(), handle, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassClassifierConfig":
        """Load a config from a Hub id, a local directory, or a file path.

        Falls back to all-default values when no config.json can be found
        or the Hub download fails.
        """
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "config.json")
            except Exception:
                # Hub download failed; use defaults rather than raising.
                return cls()
        elif os.path.isdir(pretrained_path):
            config_path = os.path.join(pretrained_path, "config.json")
        else:
            config_path = pretrained_path

        if not os.path.exists(config_path):
            return cls()

        with open(config_path, 'r') as handle:
            raw = json.load(handle)
        # Keep only known dataclass fields so configs written by newer
        # versions (with extra keys) still load cleanly.
        known = set(cls.__dataclass_fields__)
        return cls(**{key: value for key, value in raw.items() if key in known})
|
|
|
|
class LookingGlassClassifier(nn.Module):
    """
    LookingGlass encoder with a classification head.

    Pools the encoder hidden states with concat pooling (max + mean +
    last token) and classifies the pooled vector with a small
    BatchNorm -> Dropout -> Linear MLP. The layer ordering matches the
    key mapping used by ``convert_classifier_weights``.

    Example:
        >>> model = LookingGlassClassifier.from_pretrained('.')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> logits = model(inputs['input_ids'])  # (1, num_classes)
        >>> prediction = logits.argmax(dim=-1)
    """

    def __init__(self, config: Optional[LookingGlassClassifierConfig] = None):
        """
        Args:
            config: Model configuration; a default config is used when None.
        """
        super().__init__()
        self.config = config or LookingGlassClassifierConfig()
        self.encoder = _AWDLSTMEncoder(self.config)

        # Concat pooling stacks max, mean, and last-token vectors,
        # hence 3x the encoder hidden size.
        pooled_size = 3 * self.config.hidden_size

        self.classifier = nn.Sequential(
            nn.BatchNorm1d(pooled_size),
            nn.Dropout(self.config.classifier_dropout),
            nn.Linear(pooled_size, self.config.classifier_hidden),
            nn.ReLU(),
            nn.BatchNorm1d(self.config.classifier_hidden),
            nn.Dropout(self.config.classifier_dropout),
            nn.Linear(self.config.classifier_hidden, self.config.num_classes),
        )

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Forward pass returning classification logits.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Logits (batch, num_classes)
        """
        # Reset the encoder's recurrent state so each batch is independent.
        self.encoder.reset()
        hidden = self.encoder(input_ids)

        # Concat pooling: max + mean over time plus the last time step.
        max_pool = hidden.max(dim=1).values
        mean_pool = hidden.mean(dim=1)
        last_pool = hidden[:, -1]
        pooled = torch.cat([max_pool, mean_pool, last_pool], dim=-1)

        return self.classifier(pooled)

    def predict(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return predicted class indices (batch,).

        Runs under ``torch.no_grad()``: the argmax output is integer
        indices that cannot carry gradients, so tracking autograd state
        here would only waste memory.
        """
        with torch.no_grad():
            logits = self.forward(input_ids)
        return logits.argmax(dim=-1)

    def predict_proba(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return class probabilities (batch, num_classes).

        Kept differentiable (no ``no_grad``) so callers may use the
        probabilities inside a loss if they wish.
        """
        logits = self.forward(input_ids)
        return torch.softmax(logits, dim=-1)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get sequence embeddings (last-token hidden state) from the encoder."""
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return hidden[:, -1]

    def save_pretrained(self, save_directory: str):
        """Save config.json and pytorch_model.bin into *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, config: Optional[LookingGlassClassifierConfig] = None
    ) -> "LookingGlassClassifier":
        """Build a model from a Hub id or local directory.

        Args:
            pretrained_path: HF Hub id or directory containing
                pytorch_model.bin (and optionally config.json).
            config: Explicit config; loaded from *pretrained_path* when None.

        Returns:
            Model with whatever matching weights were found loaded in;
            randomly initialized if no weight file exists.
        """
        config = config or LookingGlassClassifierConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            # NOTE: torch.load unpickles the file — only load trusted
            # checkpoints. Files written by save_pretrained are plain
            # tensor state_dicts and load fine under any torch version.
            state_dict = torch.load(model_path, map_location='cpu')
            # strict=False tolerates partial checkpoints (e.g. an
            # encoder-only state_dict with a fresh classifier head).
            model.load_state_dict(state_dict, strict=False)

        return model
|
|
|
|
def convert_classifier_weights(
    original_path: str,
    output_dir: str,
    num_classes: int,
    class_names: Optional[List[str]] = None,
):
    """
    Convert original fastai classifier weights to pure PyTorch format.

    Args:
        original_path: Path to original .pth file
        output_dir: Output directory for converted model
        num_classes: Number of output classes
        class_names: Optional list of class names

    Returns:
        LookingGlassClassifier with the converted weights loaded.
    """
    print(f"Loading weights from {original_path}...")
    # fastai checkpoints contain pickled non-tensor objects, so this must be
    # a full pickle load: pass weights_only=False explicitly, since torch
    # >= 2.6 flipped the default to True (which would break this load).
    # SECURITY: only convert checkpoint files from trusted sources.
    try:
        original = torch.load(original_path, map_location='cpu', weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only keyword.
        original = torch.load(original_path, map_location='cpu')
    if 'model' in original:
        original = original['model']

    config = LookingGlassClassifierConfig(
        num_classes=num_classes,
        classifier_hidden=50,
        class_names=class_names or [],
    )

    model = LookingGlassClassifier(config)

    # Map fastai parameter names -> pure-PyTorch module names.
    # Encoder: embedding + embedding-dropout copies.
    weight_map = {
        '0.module.encoder.weight': 'encoder.embed_tokens.weight',
        '0.module.encoder_dp.emb.weight': 'encoder.embed_dropout.embedding.weight',
    }

    # Three AWD-LSTM layers, each with raw + weight-dropped tensors.
    for i in range(3):
        weight_map.update({
            f'0.module.rnns.{i}.weight_hh_l0_raw': f'encoder.layers.{i}.weight_hh_l0_raw',
            f'0.module.rnns.{i}.module.weight_ih_l0': f'encoder.layers.{i}.module.weight_ih_l0',
            f'0.module.rnns.{i}.module.weight_hh_l0': f'encoder.layers.{i}.module.weight_hh_l0',
            f'0.module.rnns.{i}.module.bias_ih_l0': f'encoder.layers.{i}.module.bias_ih_l0',
            f'0.module.rnns.{i}.module.bias_hh_l0': f'encoder.layers.{i}.module.bias_hh_l0',
        })

    # Classifier head: fastai '1.layers.N' -> nn.Sequential 'classifier.N'
    # (BatchNorm at 0 and 4, Linear at 2 and 6 — matches __init__ layout).
    classifier_map = {
        '1.layers.0.weight': 'classifier.0.weight',
        '1.layers.0.bias': 'classifier.0.bias',
        '1.layers.0.running_mean': 'classifier.0.running_mean',
        '1.layers.0.running_var': 'classifier.0.running_var',
        '1.layers.0.num_batches_tracked': 'classifier.0.num_batches_tracked',
        '1.layers.2.weight': 'classifier.2.weight',
        '1.layers.2.bias': 'classifier.2.bias',
        '1.layers.4.weight': 'classifier.4.weight',
        '1.layers.4.bias': 'classifier.4.bias',
        '1.layers.4.running_mean': 'classifier.4.running_mean',
        '1.layers.4.running_var': 'classifier.4.running_var',
        '1.layers.4.num_batches_tracked': 'classifier.4.num_batches_tracked',
        '1.layers.6.weight': 'classifier.6.weight',
        '1.layers.6.bias': 'classifier.6.bias',
    }
    weight_map.update(classifier_map)

    new_state = {}
    for old_key, new_key in weight_map.items():
        if old_key in original:
            new_state[new_key] = original[old_key]

    # Surface silently-dropped keys instead of converting a partial model
    # without any indication.
    missing = [key for key in weight_map if key not in original]
    if missing:
        print(f"Warning: {len(missing)} expected key(s) missing from checkpoint: {missing}")

    # strict=False: unmapped buffers (e.g. dropout masks) stay initialized.
    model.load_state_dict(new_state, strict=False)

    os.makedirs(output_dir, exist_ok=True)
    config.save_pretrained(output_dir)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

    print(f"Saved to {output_dir}")
    return model
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: convert a fastai checkpoint to this module's format.
    import argparse

    cli = argparse.ArgumentParser(description="Convert LookingGlass classifier weights")
    cli.add_argument("--input", required=True, help="Path to original .pth file")
    cli.add_argument("--output", required=True, help="Output directory")
    cli.add_argument("--num-classes", type=int, required=True, help="Number of classes")
    cli.add_argument("--class-names", nargs="+", help="Class names")
    opts = cli.parse_args()

    convert_classifier_weights(opts.input, opts.output, opts.num_classes, opts.class_names)
|
|