Feature Extraction
Transformers
PyTorch
Safetensors
English
magi
Manga
Object Detection
OCR
Clustering
Diarisation
custom_code
Instructions to use ragavsachdeva/magi with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ragavsachdeva/magi with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="ragavsachdeva/magi", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor | |
| from transformers.image_transforms import center_to_corners_format | |
| import torch | |
| from typing import List | |
| from shapely.geometry import box | |
| from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order, x1y1x2y2_to_xywh | |
| import numpy as np | |
class MagiProcessor:
    """Pre- and post-processing helpers around the Magi detection/OCR/embedding model.

    Depending on the ``disable_*`` flags of ``config``, up to three Hugging Face
    processors are instantiated: one for detection inputs, one for OCR and one
    for crop embeddings. A processor whose feature is disabled stays ``None``.

    Class labels used throughout: 0 for characters, 1 for text, 2 for panels.
    """

    def __init__(self, config):
        """Build the sub-processors that are enabled in ``config``.

        Args:
            config: Object exposing ``disable_detections``, ``disable_ocr`` and
                ``disable_crop_embeddings`` plus the matching
                ``detection_image_preprocessing_config``,
                ``ocr_pretrained_processor_path`` and
                ``crop_embedding_image_preprocessing_config`` attributes.
        """
        self.config = config
        self.detection_image_preprocessor = None
        self.ocr_preprocessor = None
        self.crop_embedding_image_preprocessor = None
        if not config.disable_detections:
            assert config.detection_image_preprocessing_config is not None
            self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
                config.detection_image_preprocessing_config
            )
        if not config.disable_ocr:
            assert config.ocr_pretrained_processor_path is not None
            self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
                config.ocr_pretrained_processor_path
            )
        if not config.disable_crop_embeddings:
            assert config.crop_embedding_image_preprocessing_config is not None
            self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
                config.crop_embedding_image_preprocessing_config
            )

    def preprocess_inputs_for_detection(self, images, annotations=None):
        """Run the detection image processor on ``images``.

        Args:
            images: Iterable of ``np.ndarray`` images (only the first element's
                type is checked).
            annotations: Optional annotations in the format described by
                ``_verify_annotations_are_in_correct_format``; they are
                converted to COCO-style dicts before being forwarded.

        Returns:
            Whatever the detection processor returns for ``return_tensors="pt"``.
        """
        images = list(images)
        assert isinstance(images[0], np.ndarray)
        annotations = self._convert_annotations_to_coco_format(annotations)
        return self.detection_image_preprocessor(images, annotations=annotations, return_tensors="pt")

    def preprocess_inputs_for_ocr(self, images):
        """Return the OCR processor's ``pixel_values`` for a batch of ``np.ndarray`` images."""
        images = list(images)
        assert isinstance(images[0], np.ndarray)
        return self.ocr_preprocessor(images, return_tensors="pt").pixel_values

    def preprocess_inputs_for_crop_embeddings(self, images):
        """Return the crop-embedding processor's ``pixel_values`` for ``np.ndarray`` images."""
        images = list(images)
        assert isinstance(images[0], np.ndarray)
        return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values

    def postprocess_detections_and_associations(
        self,
        predicted_bboxes,
        predicted_class_scores,
        original_image_sizes,
        get_character_character_matching_scores,
        get_text_character_matching_scores,
        get_dialog_confidence_scores,
        character_detection_threshold=0.3,
        panel_detection_threshold=0.2,
        text_detection_threshold=0.25,
        character_character_matching_threshold=0.65,
        text_character_matching_threshold=0.4,
    ):
        """Turn raw detector outputs into per-image detections and associations.

        Args:
            predicted_bboxes: Batched boxes in centre format, relative to the
                image size (they are converted to corners and multiplied by
                ``[w, h, w, h]``).
            predicted_class_scores: Batched per-query class logits; the best
                class per query is taken and its logit passed through a sigmoid.
            original_image_sizes: Either a list/tuple of ``(height, width)``
                pairs or a tensor that can be ``unbind(1)``-ed into heights and
                widths.
            get_character_character_matching_scores: Callable invoked as
                ``f(batch_character_indices, batch_bboxes)``; presumably returns
                one square score matrix per image -- confirm against caller.
            get_text_character_matching_scores: Callable invoked as
                ``f(batch_text_indices, batch_character_indices)``; presumably
                returns one (texts x characters) matrix per image.
            get_dialog_confidence_scores: Callable invoked as
                ``f(batch_text_indices)``; presumably returns one score vector
                per image.
            character_detection_threshold: Minimum score (exclusive) for characters.
            panel_detection_threshold: Minimum score (inclusive) for panels.
            text_detection_threshold: Minimum score (exclusive) for texts.
            character_character_matching_threshold: Score above which two
                characters are linked.
            text_character_matching_threshold: Dialog-confidence value above
                which a text box is assigned to its most likely speaker.

        Returns:
            One dict per image with the keys ``panels``, ``panel_scores``,
            ``texts``, ``text_scores``, ``characters``, ``character_scores``,
            ``character_character_associations``,
            ``text_character_associations``, ``character_cluster_labels`` and
            ``dialog_confidences``.
        """
        assert self.config.disable_detections is False
        batch_scores, batch_labels = predicted_class_scores.max(-1)
        batch_scores = batch_scores.sigmoid()
        batch_labels = batch_labels.long()
        batch_bboxes = center_to_corners_format(predicted_bboxes)

        # Scale the bboxes back to the original image size.
        if isinstance(original_image_sizes, (list, tuple)):
            img_h = torch.Tensor([size[0] for size in original_image_sizes])
            img_w = torch.Tensor([size[1] for size in original_image_sizes])
        else:
            img_h, img_w = original_image_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
        batch_bboxes = batch_bboxes * scale_fct[:, None, :]

        batch_panel_indices = self._get_indices_of_panels_to_keep(
            batch_scores, batch_labels, batch_bboxes, panel_detection_threshold
        )
        batch_character_indices = self._get_indices_of_characters_to_keep(
            batch_scores, batch_labels, batch_bboxes, character_detection_threshold
        )
        batch_text_indices = self._get_indices_of_texts_to_keep(
            batch_scores, batch_labels, batch_bboxes, text_detection_threshold
        )

        batch_character_character_matching_scores = get_character_character_matching_scores(
            batch_character_indices, batch_bboxes
        )
        batch_text_character_matching_scores = get_text_character_matching_scores(
            batch_text_indices, batch_character_indices
        )
        batch_dialog_confidence_scores = get_dialog_confidence_scores(batch_text_indices)

        results = []
        for batch_index in range(len(batch_scores)):
            bboxes = batch_bboxes[batch_index]
            scores = batch_scores[batch_index]
            panel_indices = batch_panel_indices[batch_index]
            text_indices_to_keep = batch_text_indices[batch_index]
            character_indices = batch_character_indices[batch_index]

            # Panels first: texts are ordered relative to the sorted panels.
            panel_bboxes = bboxes[panel_indices]
            panel_scores = scores[panel_indices]
            panel_order = sort_panels(panel_bboxes)
            panel_bboxes = panel_bboxes[panel_order]
            panel_scores = panel_scores[panel_order]

            text_bboxes = bboxes[text_indices_to_keep]
            text_scores = scores[text_indices_to_keep]
            text_order = sort_text_boxes_in_reading_order(text_bboxes, panel_bboxes)
            text_bboxes = text_bboxes[text_order]
            text_scores = text_scores[text_order]
            # Keep the per-text score containers aligned with the new text order
            # (the containers returned by the callbacks are updated in place).
            batch_text_character_matching_scores[batch_index] = (
                batch_text_character_matching_scores[batch_index][text_order]
            )
            batch_dialog_confidence_scores[batch_index] = (
                batch_dialog_confidence_scores[batch_index][text_order]
            )

            character_bboxes = bboxes[character_indices]
            character_scores = scores[character_indices]

            # Boolean adjacency between characters; reused for pairs and clusters.
            characters_match = (
                batch_character_character_matching_scores[batch_index]
                > character_character_matching_threshold
            )
            char_i, char_j = torch.where(characters_match)
            character_character_associations = torch.stack([char_i, char_j], dim=1)

            text_character_scores = batch_text_character_matching_scores[batch_index]
            dialog_confidences = batch_dialog_confidence_scores[batch_index]
            text_boxes_to_match = dialog_confidences > text_character_matching_threshold
            if 0 in text_character_scores.shape:
                # No texts or no characters: nothing can be associated.
                text_character_associations = torch.zeros((0, 2), dtype=torch.long)
            else:
                most_likely_speaker_for_each_text = torch.argmax(text_character_scores, dim=1)[text_boxes_to_match]
                text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_speaker_for_each_text)[text_boxes_to_match]
                text_character_associations = torch.stack(
                    [text_indices, most_likely_speaker_for_each_text], dim=1
                )

            character_ufds = UnionFind.from_adj_matrix(characters_match)
            results.append({
                "panels": panel_bboxes.tolist(),
                "panel_scores": panel_scores.tolist(),
                "texts": text_bboxes.tolist(),
                "text_scores": text_scores.tolist(),
                "characters": character_bboxes.tolist(),
                "character_scores": character_scores.tolist(),
                "character_character_associations": character_character_associations.tolist(),
                "text_character_associations": text_character_associations.tolist(),
                "character_cluster_labels": character_ufds.get_labels_for_connected_components(),
                "dialog_confidences": dialog_confidences.tolist(),
            })
        return results

    def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
        """Decode generated OCR token ids with the OCR processor's ``batch_decode``."""
        return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)

    def crop_image(self, image, bboxes):
        """Cut one crop per bounding box out of ``image``.

        Each box is truncated to integers, its corners are put in order, it is
        clipped to the image, and it is widened to at least 10 pixels per side
        where the image allows it.

        Args:
            image: Array indexed as ``image[y, x, ...]``.
            bboxes: Iterable of ``(x1, y1, x2, y2)`` boxes.

        Returns:
            List with one ``image[y1:y2, x1:x2]`` view per box.
        """
        height, width = image.shape[0], image.shape[1]
        crops_for_image = []
        for bbox in bboxes:
            x1, y1, x2, y2 = (int(value) for value in bbox)
            x1, x2 = min(x1, x2), max(x1, x2)  # just in case the corners are swapped
            y1, y2 = min(y1, y2), max(y1, y2)
            x1, x2 = MagiProcessor._clip_and_widen_span(x1, x2, width)
            y1, y2 = MagiProcessor._clip_and_widen_span(y1, y2, height)
            crops_for_image.append(image[y1:y2, x1:x2])
        return crops_for_image

    @staticmethod
    def _clip_and_widen_span(start, end, limit, min_size=10):
        """Clip ``[start, end)`` to ``[0, limit]`` and widen it to ``min_size`` if possible."""
        start = min(max(0, start), limit)
        end = min(max(0, end), limit)
        if end - start < min_size:
            if limit - start > min_size:
                end = start + min_size
            else:
                # Grow towards the origin, but never below 0: a negative start
                # would wrap around when slicing and give an empty/wrong crop.
                start = max(0, end - min_size)
                end = min(limit, start + min_size)
        return start, end

    def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
        """Return, per image, a tensor of indices with label 0 and score above the threshold."""
        return [
            torch.where((labels == 0) & (scores > character_detection_threshold))[0]
            for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes)
        ]

    def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
        """Return, per image, the indices (list of ints) of panels to keep.

        Candidates with label 2 are visited from highest to lowest score. A
        candidate is dropped when its score is below the threshold or when more
        than half of its area is already covered by the panels kept so far.
        """
        indices_of_panels_to_keep = []
        for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
            indices = torch.where(labels == 2)[0]
            if len(indices) == 0:
                indices_of_panels_to_keep.append([])
                continue
            # Sort on (score, index) so the bbox tensors are never compared.
            ranked = sorted(zip(scores[indices].tolist(), indices.tolist()), reverse=True)
            kept_indices = []
            union_of_panels_so_far = box(0, 0, 0, 0)
            for _, index in ranked:
                if scores[index] < panel_detection_threshold:
                    break  # candidates are in descending score order
                panel_polygon = box(*bboxes[index].tolist())
                panel_area = panel_polygon.area
                if panel_area <= 0:
                    continue  # degenerate box; would otherwise divide by zero
                if union_of_panels_so_far.intersection(panel_polygon).area / panel_area > 0.5:
                    continue
                kept_indices.append(index)
                union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
            indices_of_panels_to_keep.append(kept_indices)
        return indices_of_panels_to_keep

    def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
        """Return, per image, the indices (list of ints) of text boxes to keep.

        Candidates with label 1 and score above the threshold are visited from
        highest to lowest score. A candidate is dropped when its intersection
        over union with an already kept text box exceeds 0.5.
        """
        indices_of_texts_to_keep = []
        for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
            indices = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
            if len(indices) == 0:
                indices_of_texts_to_keep.append([])
                continue
            # Sort on (score, index) so the bbox tensors are never compared.
            ranked = sorted(zip(scores[indices].tolist(), indices.tolist()), reverse=True)
            kept_indices = []
            kept_polygons = []
            for _, index in ranked:
                text_polygon = box(*bboxes[index].tolist())
                is_duplicate = False
                for kept_polygon in kept_polygons:
                    union_area = kept_polygon.union(text_polygon).area
                    # A zero union area means both boxes are degenerate.
                    if union_area > 0 and kept_polygon.intersection(text_polygon).area / union_area > 0.5:
                        is_duplicate = True
                        break
                if not is_duplicate:
                    kept_indices.append(index)
                    kept_polygons.append(text_polygon)
            indices_of_texts_to_keep.append(kept_indices)
        return indices_of_texts_to_keep

    def _convert_annotations_to_coco_format(self, annotations):
        """Convert annotations to the COCO-style dicts passed to the detection processor.

        Returns ``None`` when ``annotations`` is ``None``; otherwise one dict
        per image with ``image_id`` and a list of ``bbox`` (xywh),
        ``category_id`` and ``area`` entries.

        Raises:
            ValueError: If the annotations are not in the expected format.
        """
        if annotations is None:
            return None
        self._verify_annotations_are_in_correct_format(annotations)
        coco_annotations = []
        for annotation in annotations:
            coco_annotations.append({
                "image_id": annotation["image_id"],
                "annotations": [
                    {
                        "bbox": x1y1x2y2_to_xywh(bbox),
                        "category_id": label,
                        "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
                    }
                    for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"])
                ],
            })
        return coco_annotations

    def _verify_annotations_are_in_correct_format(self, annotations):
        """Validate the structure of ``annotations``.

        ``None`` and empty sequences are accepted. Every entry (not just the
        first) must be a dict containing ``image_id``, ``bboxes_as_x1y1x2y2``
        and ``labels``.

        Raises:
            ValueError: If the container or any entry has the wrong shape.
        """
        error_msg = """
        Annotations must be in the following format:
        [
            {
                "image_id": 0,
                "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
                "labels": [0, 1, 2],
            },
            ...
        ]
        Labels: 0 for characters, 1 for text, 2 for panels.
        """
        if annotations is None:
            return
        if not isinstance(annotations, (list, tuple)):
            raise ValueError(
                f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
            )
        for annotation in annotations:
            if not isinstance(annotation, dict):
                raise ValueError(
                    f"{error_msg} Expected a List[Dict], found {type(annotation)}."
                )
            for required_key in ("image_id", "bboxes_as_x1y1x2y2", "labels"):
                if required_key not in annotation:
                    raise ValueError(
                        f"{error_msg} Dict must contain '{required_key}'."
                    )